Skip to content

Commit

Permalink
feat!: Update cross-media-measurement-api dep for Requisition.State.W…
Browse files Browse the repository at this point in the history
…ITHDRAWN (#1746)

This sets the Requisition state to WITHDRAWN when a Measurement is cancelled or failed.
  • Loading branch information
SanjayVas authored Aug 16, 2024
1 parent 067c0aa commit 362f05e
Show file tree
Hide file tree
Showing 21 changed files with 232 additions and 41 deletions.
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ bazel_dep(
)
bazel_dep(
name = "cross-media-measurement-api",
version = "0.64.0",
version = "0.65.0",
repo_name = "wfa_measurement_proto",
)
bazel_dep(
Expand Down
24 changes: 12 additions & 12 deletions MODULE.bazel.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
Expand Down Expand Up @@ -92,8 +93,8 @@ private val BASE_SQL =
class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantReader.Result>() {
data class Result(
val computationParticipant: ComputationParticipant,
val measurementId: Long,
val measurementConsumerId: Long,
val measurementId: InternalId,
val measurementConsumerId: InternalId,
val measurementState: Measurement.State,
val measurementDetails: Measurement.Details,
)
Expand Down Expand Up @@ -140,8 +141,8 @@ class ComputationParticipantReader : BaseSpannerReader<ComputationParticipantRea
override suspend fun translate(struct: Struct) =
Result(
buildComputationParticipant(struct),
struct.getLong("MeasurementId"),
struct.getLong("MeasurementConsumerId"),
struct.getInternalId("MeasurementId"),
struct.getInternalId("MeasurementConsumerId"),
struct.getProtoEnum("MeasurementState", Measurement.State::forNumber),
struct.getProtoMessage("MeasurementDetails", Measurement.Details.parser()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class BatchCancelMeasurements(private val requests: BatchCancelMeasurementsReque
previousState = result.measurement.state,
measurementLogEntryDetails = measurementLogEntryDetails,
)
withdrawRequisitions(result.measurementConsumerId, result.measurementId)
}

return batchCancelMeasurementsResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.gcloud.spanner.set
import org.wfanet.measurement.internal.kingdom.Measurement
import org.wfanet.measurement.internal.kingdom.MeasurementLogEntryKt
import org.wfanet.measurement.internal.kingdom.copy
Expand Down Expand Up @@ -50,7 +51,8 @@ class CancelMeasurement(

when (val state = measurement.state) {
Measurement.State.PENDING_REQUISITION_PARAMS,
Measurement.State.PENDING_REQUISITION_FULFILLMENT,
Measurement.State.PENDING_REQUISITION_FULFILLMENT ->
withdrawRequisitions(measurementConsumerId, measurementId)
Measurement.State.PENDING_PARTICIPANT_CONFIRMATION,
Measurement.State.PENDING_COMPUTATION -> {}
Measurement.State.SUCCEEDED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,16 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
}

val duchyIds: List<InternalId> =
getComputationParticipantsDuchyIds(
InternalId(measurementConsumerId),
InternalId(measurementId),
)
.filter { it.value != duchyId }
getComputationParticipantsDuchyIds(measurementConsumerId, measurementId).filter {
it.value != duchyId
}

if (
computationParticipantsInState(
transactionContext,
duchyIds,
InternalId(measurementConsumerId),
InternalId(measurementId),
measurementConsumerId,
measurementId,
NEXT_COMPUTATION_PARTICIPANT_STATE,
)
) {
Expand All @@ -139,8 +137,8 @@ class ConfirmComputationParticipant(private val request: ConfirmComputationParti
}

updateMeasurementState(
measurementConsumerId = InternalId(measurementConsumerId),
measurementId = InternalId(measurementId),
measurementConsumerId = measurementConsumerId,
measurementId = measurementId,
nextState = Measurement.State.PENDING_COMPUTATION,
previousState = measurementState,
measurementLogEntryDetails = measurementLogEntryDetails,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class FailComputationParticipant(private val request: FailComputationParticipant

when (measurementState) {
Measurement.State.PENDING_REQUISITION_PARAMS,
Measurement.State.PENDING_REQUISITION_FULFILLMENT,
Measurement.State.PENDING_REQUISITION_FULFILLMENT ->
withdrawRequisitions(measurementConsumerId, measurementId)
Measurement.State.PENDING_PARTICIPANT_CONFIRMATION,
Measurement.State.PENDING_COMPUTATION -> {}
Measurement.State.FAILED,
Expand Down Expand Up @@ -142,17 +143,17 @@ class FailComputationParticipant(private val request: FailComputationParticipant
}

updateMeasurementState(
measurementConsumerId = InternalId(measurementConsumerId),
measurementId = InternalId(measurementId),
measurementConsumerId = measurementConsumerId,
measurementId = measurementId,
nextState = Measurement.State.FAILED,
previousState = measurementState,
measurementLogEntryDetails = measurementLogEntryDetails,
details = updatedMeasurementDetails,
)

insertDuchyMeasurementLogEntry(
InternalId(measurementId),
InternalId(measurementConsumerId),
measurementId,
measurementConsumerId,
InternalId(duchyId),
duchyMeasurementLogEntry.details,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class RefuseRequisition(private val request: RefuseRequisitionRequest) :
SpannerWriter<Requisition, Requisition>() {
override suspend fun TransactionScope.runTransaction(): Requisition {
val readResult: RequisitionReader.Result = readRequisition()
val (measurementConsumerId, measurementId, _, requisition, measurementDetails) = readResult
val (measurementConsumerId, measurementId, requisitionId, requisition, measurementDetails) =
readResult

val state = requisition.state
if (state != Requisition.State.UNFULFILLED) {
Expand Down Expand Up @@ -97,6 +98,8 @@ class RefuseRequisition(private val request: RefuseRequisitionRequest) :
details = updatedMeasurementDetails,
)

withdrawRequisitions(measurementConsumerId, measurementId, requisitionId)

return requisition.copy {
this.state = Requisition.State.REFUSED
details = updatedDetails
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,5 +252,7 @@ class RevokeCertificate(private val request: RevokeCertificateRequest) :
measurementLogEntryDetails = measurementLogEntryDetails,
details = details,
)

withdrawRequisitions(measurementConsumerId, measurementId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ class SetParticipantRequisitionParams(private val request: SetParticipantRequisi
)
}

val measurementId = InternalId(computationParticipantResult.measurementId)
val measurementConsumerId = InternalId(computationParticipantResult.measurementConsumerId)
val measurementId = computationParticipantResult.measurementId
val measurementConsumerId = computationParticipantResult.measurementConsumerId

if (computationParticipant.state != ComputationParticipant.State.CREATED) {
throw ComputationParticipantStateIllegalException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@

package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import com.google.cloud.spanner.Key
import com.google.cloud.spanner.KeySet
import com.google.cloud.spanner.Value
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation
import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoEnum
import org.wfanet.measurement.gcloud.spanner.set
import org.wfanet.measurement.gcloud.spanner.toProtoEnum
import org.wfanet.measurement.internal.kingdom.Requisition
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.RequisitionReader

Expand All @@ -39,3 +44,41 @@ internal fun SpannerWriter.TransactionScope.updateRequisition(
}
}
}

internal suspend fun SpannerWriter.TransactionScope.withdrawRequisitions(
measurementConsumerId: InternalId,
measurementId: InternalId,
excludedRequisitionId: InternalId? = null,
) {
val keyPrefix = Key.of(measurementConsumerId.value, measurementId.value)
transactionContext
.read("Requisitions", KeySet.prefixRange(keyPrefix), listOf("RequisitionId", "State"))
.collect { row ->
val requisitionId = row.getInternalId("RequisitionId")
if (requisitionId == excludedRequisitionId) {
return@collect
}

val requisitionState: Requisition.State =
row.getProtoEnum("State") {
Requisition.State.forNumber(it) ?: Requisition.State.UNRECOGNIZED
}
when (requisitionState) {
Requisition.State.PENDING_PARAMS,
Requisition.State.UNFULFILLED -> {}
Requisition.State.FULFILLED,
Requisition.State.REFUSED,
Requisition.State.WITHDRAWN,
Requisition.State.STATE_UNSPECIFIED,
Requisition.State.UNRECOGNIZED -> return@collect
}

transactionContext.bufferUpdateMutation("Requisitions") {
set("MeasurementConsumerId" to measurementConsumerId)
set("MeasurementId" to measurementId)
set("RequisitionId" to requisitionId)
set("State").toProtoEnum(Requisition.State.WITHDRAWN)
set("UpdateTime").to(Value.COMMIT_TIMESTAMP)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,7 @@ fun InternalRequisition.State.toRequisitionState(): Requisition.State =
InternalRequisition.State.UNFULFILLED -> Requisition.State.UNFULFILLED
InternalRequisition.State.FULFILLED -> Requisition.State.FULFILLED
InternalRequisition.State.REFUSED -> Requisition.State.REFUSED
InternalRequisition.State.WITHDRAWN -> Requisition.State.WITHDRAWN
InternalRequisition.State.STATE_UNSPECIFIED,
InternalRequisition.State.UNRECOGNIZED -> Requisition.State.STATE_UNSPECIFIED
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ private fun State.toInternal(): InternalState =
State.UNFULFILLED -> InternalState.UNFULFILLED
State.FULFILLED -> InternalState.FULFILLED
State.REFUSED -> InternalState.REFUSED
State.WITHDRAWN -> InternalState.WITHDRAWN
State.STATE_UNSPECIFIED,
State.UNRECOGNIZED -> InternalState.STATE_UNSPECIFIED
}
Expand Down Expand Up @@ -527,6 +528,7 @@ private fun buildInternalStreamRequisitionsRequest(
states += InternalState.UNFULFILLED
states += InternalState.FULFILLED
states += InternalState.REFUSED
states += InternalState.WITHDRAWN
} else {
states += requestStates.map { it.toInternal() }
}
Expand Down
Loading

0 comments on commit 362f05e

Please sign in to comment.