diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ConfirmComputationParticipant.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ConfirmComputationParticipant.kt index 34f627c7b03..a050ba9609b 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ConfirmComputationParticipant.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ConfirmComputationParticipant.kt @@ -14,11 +14,16 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers +import com.google.cloud.spanner.Statement import com.google.cloud.spanner.Value +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.InternalId +import org.wfanet.measurement.gcloud.spanner.bind import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation import org.wfanet.measurement.gcloud.spanner.set +import org.wfanet.measurement.gcloud.spanner.statement import org.wfanet.measurement.internal.kingdom.ComputationParticipant import org.wfanet.measurement.internal.kingdom.ConfirmComputationParticipantRequest import org.wfanet.measurement.internal.kingdom.Measurement @@ -106,13 +111,17 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti set("State" to NEXT_COMPUTATION_PARTICIPANT_STATE) } - val otherDuchyIds: List = - DuchyIds.entries.map { InternalId(it.internalDuchyId) }.filter { it.value != duchyId } + val duchyIds: List = + getComputationParticipantsDuchyIds( + InternalId(measurementConsumerId), + InternalId(measurementId) + ) + .filter { it.value != duchyId } if ( computationParticipantsInState( transactionContext, - otherDuchyIds, + duchyIds, InternalId(measurementConsumerId), InternalId(measurementId), NEXT_COMPUTATION_PARTICIPANT_STATE @@ -136,6 +145,29 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti return computationParticipant.copy { state = NEXT_COMPUTATION_PARTICIPANT_STATE } } + private suspend fun TransactionScope.getComputationParticipantsDuchyIds( + measurementConsumerId: InternalId, + measurementId: InternalId + ): List { + val sql = + """ + SELECT DuchyId + FROM ComputationParticipants + WHERE MeasurementConsumerId = @measurement_consumer_id + AND MeasurementId = @measurement_id + """ + .trimIndent() + val statement: Statement = + statement(sql) { + bind("measurement_consumer_id" to measurementConsumerId.value) + bind("measurement_id" to measurementId) + } + return transactionContext + .executeQuery(statement) + .map { InternalId(it.getLong("DuchyId")) } + .toList() + } + override fun ResultScope.buildResult(): ComputationParticipant { return checkNotNull(transactionResult).copy { updateTime = commitTimestamp.toProto() } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ComputationParticipantsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ComputationParticipantsServiceTest.kt index 2d06fa6aca0..35149b89917 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ComputationParticipantsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ComputationParticipantsServiceTest.kt @@ -129,8 +129,9 @@ abstract class ComputationParticipantsServiceTest = DUCHIES.map { it.externalDuchyId } + ) { duchyCertificates = externalDuchyIds.associateWith { externalDuchyId -> runBlocking { population.createDuchyCertificate(certificatesService, externalDuchyId) } @@ -699,6 +700,87 @@ abstract class ComputationParticipantsServiceTest = + listOf(WORKER1_DUCHY.externalDuchyId, WORKER2_DUCHY.externalDuchyId), notValidBefore: Instant = clock.instant(), notValidAfter: Instant = notValidBefore.plus(365L, ChronoUnit.DAYS), customize: (DataProviderKt.Dsl.() -> Unit)? = null @@ -165,8 +167,7 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) { publicKeySignature = "EDP public key signature".toByteStringUtf8() publicKeySignatureAlgorithmOid = "2.9999" } - requiredExternalDuchyIds += WORKER1_DUCHY.externalDuchyId - requiredExternalDuchyIds += WORKER2_DUCHY.externalDuchyId + requiredExternalDuchyIds += requiredDuchiesList customize?.invoke(this) } )