Skip to content

Commit

Permalink
Update HMSS stages. Revert Duchy ControlService into Blob-Only Patter…
Browse files Browse the repository at this point in the history
…n. (#1476)
  • Loading branch information
renjiezh authored Feb 27, 2024
1 parent d69e6a8 commit 6fa65e1
Show file tree
Hide file tree
Showing 14 changed files with 334 additions and 829 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.UNRECOGNIZED
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_TO_START
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.stageDetails
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.waitOnAggregationInputDetails

Expand All @@ -57,11 +59,11 @@ object HonestMajorityShareShuffleProtocol {

override val validSuccessors =
mapOf(
INITIALIZED to setOf(SETUP_PHASE),
// A Non-aggregator will skip WAIT_ON_SHUFFLE_INPUT into SHUFFLE_PHASE if the requisition
// data from EDPs and seed from the peer worker have been received.
SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT, SHUFFLE_PHASE),
WAIT_ON_SHUFFLE_INPUT to setOf(SHUFFLE_PHASE),
INITIALIZED to setOf(WAIT_TO_START, WAIT_ON_SHUFFLE_INPUT_PHASE_ONE),
WAIT_TO_START to setOf(SETUP_PHASE),
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE to setOf(SETUP_PHASE),
SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, SHUFFLE_PHASE),
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO to setOf(SHUFFLE_PHASE),
WAIT_ON_AGGREGATION_INPUT to setOf(AGGREGATION_PHASE),
SHUFFLE_PHASE to setOf(COMPLETE),
AGGREGATION_PHASE to setOf(COMPLETE),
Expand Down Expand Up @@ -91,10 +93,16 @@ object HonestMajorityShareShuffleProtocol {
): Boolean {
return when (stage) {
INITIALIZED,
WAIT_TO_START,
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
SETUP_PHASE,
SHUFFLE_PHASE -> details.role == RoleInComputation.NON_AGGREGATOR
WAIT_ON_AGGREGATION_INPUT,
AGGREGATION_PHASE -> details.role == RoleInComputation.AGGREGATOR
else -> true /* Stage can be executed at either primary or non-primary */
COMPLETE -> true /* Stage can be executed at either AGGREGATOR or NON_AGGREGATOR */
STAGE_UNSPECIFIED,
UNRECOGNIZED -> error("Invalid Stage. $stage")
}
}

Expand All @@ -106,7 +114,9 @@ object HonestMajorityShareShuffleProtocol {
SETUP_PHASE,
SHUFFLE_PHASE,
AGGREGATION_PHASE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE
WAIT_ON_SHUFFLE_INPUT,
WAIT_TO_START,
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
WAIT_ON_AGGREGATION_INPUT -> AfterTransition.DO_NOT_ADD_TO_QUEUE
COMPLETE -> error("Computation should be ended with call to endComputation(...)")
// Stages that we can't transition to ever.
Expand All @@ -122,19 +132,18 @@ object HonestMajorityShareShuffleProtocol {
): Int {
return when (stage) {
SETUP_PHASE,
WAIT_ON_SHUFFLE_INPUT,
SHUFFLE_PHASE -> 0
WAIT_TO_START -> 0
// The output of these stages are the data received from the peer non-aggregator duchy:
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
// The output of these stages are the computed intermediate data:
SHUFFLE_PHASE,
AGGREGATION_PHASE -> 1
WAIT_ON_AGGREGATION_INPUT -> 2
AGGREGATION_PHASE ->
// The output is the intermediate computation result either received from another duchy
// or computed locally.
1
// Mill have nothing to do for this stage.
COMPLETE -> error("Computation should be ended with call to endComputation(...)")
// Stages that we can't transition to ever.
INITIALIZED,
UNRECOGNIZED,
STAGE_UNSPECIFIED,
INITIALIZED -> error("Cannot make transition function to stage $stage")
STAGE_UNSPECIFIED -> error("Cannot make transition function to stage $stage")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoro
import org.wfanet.measurement.internal.duchy.GetOutputBlobMetadataRequest
import org.wfanet.measurement.internal.duchy.getComputationTokenRequest
import org.wfanet.measurement.internal.duchy.recordOutputBlobPathRequest
import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest

/** Implementation of the internal Async Computation Control Service. */
class AsyncComputationControlService(
Expand Down Expand Up @@ -90,7 +89,7 @@ class AsyncComputationControlService(
.asRuntimeException()

val computationStage = token.computationStage
if (!stages.isValidStage(computationStage, request.computationStage)) {
if (computationStage != request.computationStage) {
if (computationStage == stages.nextStage(request.computationStage)) {
// This is technically an error, but it should be safe to treat as a no-op.
logger.warning { "[id=$globalComputationId]: Computation stage has already been advanced" }
Expand All @@ -103,57 +102,43 @@ class AsyncComputationControlService(
.asRuntimeException()
}

if (stages.expectBlob(computationStage)) {
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response =
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
}

// Computation has changed, so use the new token.
token = response.token
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
}

if (stages.expectStageInput(token)) {
val computationDetails =
stages.updateComputationDetails(token.computationDetails, request.computationStageInput)
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response =
computationsClient.updateComputationDetails(
updateComputationDetailsRequest {
this.token = token
details = computationDetails
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
)
}

// Computation has changed, so use the new token.
token = response.token
}

if (stages.readyForNextStage(token)) {
// Advance the computation to next stage if all blob paths are present.
if (!token.outputPathList().any(String::isEmpty)) {
try {
computationsClient.advanceComputationStage(
computationToken = token,
Expand Down
Loading

0 comments on commit 6fa65e1

Please sign in to comment.