Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marcopremier fix confirmcomputationparticipant #1272

Merged
merged 7 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Value
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation
import org.wfanet.measurement.gcloud.spanner.set
import org.wfanet.measurement.gcloud.spanner.statement
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ConfirmComputationParticipantRequest
import org.wfanet.measurement.internal.kingdom.Measurement
Expand Down Expand Up @@ -106,13 +111,17 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
set("State" to NEXT_COMPUTATION_PARTICIPANT_STATE)
}

val otherDuchyIds: List<InternalId> =
DuchyIds.entries.map { InternalId(it.internalDuchyId) }.filter { it.value != duchyId }
val duchyIds: List<InternalId> =
getComputationParticipantsDuchyIds(
InternalId(measurementConsumerId),
InternalId(measurementId)
)
.filter { it.value != duchyId }

if (
computationParticipantsInState(
transactionContext,
otherDuchyIds,
duchyIds,
InternalId(measurementConsumerId),
InternalId(measurementId),
NEXT_COMPUTATION_PARTICIPANT_STATE
Expand All @@ -136,6 +145,29 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
return computationParticipant.copy { state = NEXT_COMPUTATION_PARTICIPANT_STATE }
}

private suspend fun TransactionScope.getComputationParticipantsDuchyIds(
measurementConsumerId: InternalId,
measurementId: InternalId
): List<InternalId> {
val sql =
"""
SELECT DuchyId
FROM ComputationParticipants
WHERE MeasurementConsumerId = @measurement_consumer_id
AND MeasurementId = @measurement_id
"""
.trimIndent()
val statement: Statement =
statement(sql) {
bind("measurement_consumer_id" to measurementConsumerId.value)
bind("measurement_id" to measurementId)
}
return transactionContext
.executeQuery(statement)
.map { InternalId(it.getLong("DuchyId")) }
.toList()
}

override fun ResultScope<ComputationParticipant>.buildResult(): ComputationParticipant {
return checkNotNull(transactionResult).copy { updateTime = commitTimestamp.toProto() }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
accountsService = services.accountsService
}

private fun createDuchyCertificates() {
val externalDuchyIds = DUCHIES.map { it.externalDuchyId }
private fun createDuchyCertificates(
externalDuchyIds: List<String> = DUCHIES.map { it.externalDuchyId }
) {
duchyCertificates =
externalDuchyIds.associateWith { externalDuchyId ->
runBlocking { population.createDuchyCertificate(certificatesService, externalDuchyId) }
Expand Down Expand Up @@ -699,6 +700,87 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
assertThat(updatedMeasurement.state).isEqualTo(Measurement.State.PENDING_COMPUTATION)
}

@Test
fun `confirmComputationParticipant succeeds with a subset of registered duchies`() = runBlocking {
val duchies = DUCHIES.dropLast(1)
createDuchyCertificates(duchies.map { it.externalDuchyId })
val measurement =
population.createComputedMeasurement(
measurementsService,
population.createMeasurementConsumer(measurementConsumersService, accountsService),
"measurement",
population.createDataProvider(
dataProvidersService,
listOf(Population.AGGREGATOR_DUCHY.externalDuchyId)
),
)

val setParticipantRequisitionParamsDetails = liquidLegionsV2Details {
elGamalPublicKey = EL_GAMAL_PUBLIC_KEY
elGamalPublicKeySignature = EL_GAMAL_PUBLIC_KEY_SIGNATURE
}

// Step 1 - SetParticipantRequisitionParams for all ComputationParticipants. This transitions
// the measurement state to PENDING_REQUISITION_FULFILLMENT.
for (duchyCertificate in duchyCertificates.values) {
computationParticipantsService.setParticipantRequisitionParams(
setParticipantRequisitionParamsRequest {
externalComputationId = measurement.externalComputationId
externalDuchyId = duchyCertificate.externalDuchyId
externalDuchyCertificateId = duchyCertificate.externalCertificateId
liquidLegionsV2 = setParticipantRequisitionParamsDetails
}
)
}

val requisitions =
requisitionsService
.streamRequisitions(
streamRequisitionsRequest {
filter = filter {
externalMeasurementConsumerId = measurement.externalMeasurementConsumerId
externalMeasurementId = measurement.externalMeasurementId
}
}
)
.toList()

// Step 2 - FulfillRequisitions for all Requisitions. This transitions the measurement state to
// PENDING_PARTICIPANT_CONFIRMATION.
val nonce = 3127743798281582205L
for ((requisition, duchy) in requisitions zip duchies) {
requisitionsService.fulfillRequisition(
fulfillRequisitionRequest {
externalRequisitionId = requisition.externalRequisitionId
this.nonce = nonce
computedParams = computedRequisitionParams {
externalComputationId = measurement.externalComputationId
externalFulfillingDuchyId = duchy.externalDuchyId
}
}
)
}

// Step 3 - ConfirmComputationParticipant for all ComputationParticipants. This transitions
// the measurement state to PENDING_COMPUTATION.
for (duchy in duchies) {
computationParticipantsService.confirmComputationParticipant(
confirmComputationParticipantRequest {
externalComputationId = measurement.externalComputationId
this.externalDuchyId = duchy.externalDuchyId
}
)
}

val updatedMeasurement =
measurementsService.getMeasurementByComputationId(
getMeasurementByComputationIdRequest {
this.externalComputationId = measurement.externalComputationId
}
)
assertThat(updatedMeasurement.state).isEqualTo(Measurement.State.PENDING_COMPUTATION)
}

@Test
fun `failComputationParticipant fails due to illegal measurement state`() = runBlocking {
createDuchyCertificates()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) {

suspend fun createDataProvider(
dataProvidersService: DataProvidersCoroutineImplBase,
requiredDuchiesList: List<String> =
listOf(WORKER1_DUCHY.externalDuchyId, WORKER2_DUCHY.externalDuchyId),
notValidBefore: Instant = clock.instant(),
notValidAfter: Instant = notValidBefore.plus(365L, ChronoUnit.DAYS),
customize: (DataProviderKt.Dsl.() -> Unit)? = null
Expand All @@ -165,8 +167,7 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) {
publicKeySignature = "EDP public key signature".toByteStringUtf8()
publicKeySignatureAlgorithmOid = "2.9999"
}
requiredExternalDuchyIds += WORKER1_DUCHY.externalDuchyId
requiredExternalDuchyIds += WORKER2_DUCHY.externalDuchyId
requiredExternalDuchyIds += requiredDuchiesList
customize?.invoke(this)
}
)
Expand Down