Skip to content

Commit

Permalink
Marcopremier fix confirmcomputationparticipant (#1272)
Browse files Browse the repository at this point in the history
PR to solve issue #1269

---------

Co-authored-by: marcopremier <[email protected]>
  • Loading branch information
Marco-Premier and marcopremier authored Oct 11, 2023
1 parent 3beb156 commit 192c5ee
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Value
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation
import org.wfanet.measurement.gcloud.spanner.set
import org.wfanet.measurement.gcloud.spanner.statement
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ConfirmComputationParticipantRequest
import org.wfanet.measurement.internal.kingdom.Measurement
Expand Down Expand Up @@ -106,13 +111,17 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
set("State" to NEXT_COMPUTATION_PARTICIPANT_STATE)
}

val otherDuchyIds: List<InternalId> =
DuchyIds.entries.map { InternalId(it.internalDuchyId) }.filter { it.value != duchyId }
val duchyIds: List<InternalId> =
getComputationParticipantsDuchyIds(
InternalId(measurementConsumerId),
InternalId(measurementId)
)
.filter { it.value != duchyId }

if (
computationParticipantsInState(
transactionContext,
otherDuchyIds,
duchyIds,
InternalId(measurementConsumerId),
InternalId(measurementId),
NEXT_COMPUTATION_PARTICIPANT_STATE
Expand All @@ -136,6 +145,29 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
return computationParticipant.copy { state = NEXT_COMPUTATION_PARTICIPANT_STATE }
}

private suspend fun TransactionScope.getComputationParticipantsDuchyIds(
measurementConsumerId: InternalId,
measurementId: InternalId
): List<InternalId> {
val sql =
"""
SELECT DuchyId
FROM ComputationParticipants
WHERE MeasurementConsumerId = @measurement_consumer_id
AND MeasurementId = @measurement_id
"""
.trimIndent()
val statement: Statement =
statement(sql) {
bind("measurement_consumer_id" to measurementConsumerId.value)
bind("measurement_id" to measurementId)
}
return transactionContext
.executeQuery(statement)
.map { InternalId(it.getLong("DuchyId")) }
.toList()
}

override fun ResultScope<ComputationParticipant>.buildResult(): ComputationParticipant {
return checkNotNull(transactionResult).copy { updateTime = commitTimestamp.toProto() }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
accountsService = services.accountsService
}

private fun createDuchyCertificates() {
val externalDuchyIds = DUCHIES.map { it.externalDuchyId }
private fun createDuchyCertificates(
externalDuchyIds: List<String> = DUCHIES.map { it.externalDuchyId }
) {
duchyCertificates =
externalDuchyIds.associateWith { externalDuchyId ->
runBlocking { population.createDuchyCertificate(certificatesService, externalDuchyId) }
Expand Down Expand Up @@ -699,6 +700,87 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
assertThat(updatedMeasurement.state).isEqualTo(Measurement.State.PENDING_COMPUTATION)
}

@Test
fun `confirmComputationParticipant succeeds with a subset of registered duchies`() = runBlocking {
val duchies = DUCHIES.dropLast(1)
createDuchyCertificates(duchies.map { it.externalDuchyId })
val measurement =
population.createComputedMeasurement(
measurementsService,
population.createMeasurementConsumer(measurementConsumersService, accountsService),
"measurement",
population.createDataProvider(
dataProvidersService,
listOf(Population.AGGREGATOR_DUCHY.externalDuchyId)
),
)

val setParticipantRequisitionParamsDetails = liquidLegionsV2Details {
elGamalPublicKey = EL_GAMAL_PUBLIC_KEY
elGamalPublicKeySignature = EL_GAMAL_PUBLIC_KEY_SIGNATURE
}

// Step 1 - SetParticipantRequisitionParams for all ComputationParticipants. This transitions
// the measurement state to PENDING_REQUISITION_FULFILLMENT.
for (duchyCertificate in duchyCertificates.values) {
computationParticipantsService.setParticipantRequisitionParams(
setParticipantRequisitionParamsRequest {
externalComputationId = measurement.externalComputationId
externalDuchyId = duchyCertificate.externalDuchyId
externalDuchyCertificateId = duchyCertificate.externalCertificateId
liquidLegionsV2 = setParticipantRequisitionParamsDetails
}
)
}

val requisitions =
requisitionsService
.streamRequisitions(
streamRequisitionsRequest {
filter = filter {
externalMeasurementConsumerId = measurement.externalMeasurementConsumerId
externalMeasurementId = measurement.externalMeasurementId
}
}
)
.toList()

// Step 2 - FulfillRequisitions for all Requisitions. This transitions the measurement state to
// PENDING_PARTICIPANT_CONFIRMATION.
val nonce = 3127743798281582205L
for ((requisition, duchy) in requisitions zip duchies) {
requisitionsService.fulfillRequisition(
fulfillRequisitionRequest {
externalRequisitionId = requisition.externalRequisitionId
this.nonce = nonce
computedParams = computedRequisitionParams {
externalComputationId = measurement.externalComputationId
externalFulfillingDuchyId = duchy.externalDuchyId
}
}
)
}

// Step 3 - ConfirmComputationParticipant for all ComputationParticipants. This transitions
// the measurement state to PENDING_COMPUTATION.
for (duchy in duchies) {
computationParticipantsService.confirmComputationParticipant(
confirmComputationParticipantRequest {
externalComputationId = measurement.externalComputationId
this.externalDuchyId = duchy.externalDuchyId
}
)
}

val updatedMeasurement =
measurementsService.getMeasurementByComputationId(
getMeasurementByComputationIdRequest {
this.externalComputationId = measurement.externalComputationId
}
)
assertThat(updatedMeasurement.state).isEqualTo(Measurement.State.PENDING_COMPUTATION)
}

@Test
fun `failComputationParticipant fails due to illegal measurement state`() = runBlocking {
createDuchyCertificates()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) {

suspend fun createDataProvider(
dataProvidersService: DataProvidersCoroutineImplBase,
requiredDuchiesList: List<String> =
listOf(WORKER1_DUCHY.externalDuchyId, WORKER2_DUCHY.externalDuchyId),
notValidBefore: Instant = clock.instant(),
notValidAfter: Instant = notValidBefore.plus(365L, ChronoUnit.DAYS),
customize: (DataProviderKt.Dsl.() -> Unit)? = null
Expand All @@ -165,8 +167,7 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) {
publicKeySignature = "EDP public key signature".toByteStringUtf8()
publicKeySignatureAlgorithmOid = "2.9999"
}
requiredExternalDuchyIds += WORKER1_DUCHY.externalDuchyId
requiredExternalDuchyIds += WORKER2_DUCHY.externalDuchyId
requiredExternalDuchyIds += requiredDuchiesList
customize?.invoke(this)
}
)
Expand Down

0 comments on commit 192c5ee

Please sign in to comment.