Skip to content

Commit

Permalink
Implement internal control service for HMSS protocol. (#1390)
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh authored Jan 11, 2024
1 parent d4b9a65 commit 1b0f32a
Show file tree
Hide file tree
Showing 11 changed files with 955 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ object HonestMajorityShareShuffleProtocol {
): Int {
return when (stage) {
SETUP_PHASE,
WAIT_ON_SHUFFLE_INPUT -> 0
SHUFFLE_PHASE,
WAIT_ON_SHUFFLE_INPUT,
SHUFFLE_PHASE -> 0
WAIT_ON_AGGREGATION_INPUT -> 2
AGGREGATION_PHASE ->
// The output is the intermediate computation result either received from another duchy
// or computed locally.
1
WAIT_ON_AGGREGATION_INPUT -> 2
// Mill have nothing to do for this stage.
COMPLETE -> error("Computation should be ended with call to endComputation(...)")
// Stages that we can't transition to ever.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ import org.wfanet.measurement.internal.duchy.ComputationStageBlobMetadata
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub
import org.wfanet.measurement.internal.duchy.GetOutputBlobMetadataRequest
import org.wfanet.measurement.internal.duchy.RecordOutputBlobPathResponse
import org.wfanet.measurement.internal.duchy.getComputationTokenRequest
import org.wfanet.measurement.internal.duchy.recordOutputBlobPathRequest
import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest

/** Implementation of the internal Async Computation Control Service. */
class AsyncComputationControlService(
Expand Down Expand Up @@ -88,8 +88,9 @@ class AsyncComputationControlService(
"Computation with global ID $globalComputationId not found"
)
.asRuntimeException()

val computationStage = token.computationStage
if (computationStage != request.computationStage) {
if (!stages.isValidStage(computationStage, request.computationStage)) {
if (computationStage == stages.nextStage(request.computationStage)) {
// This is technically an error, but it should be safe to treat as a no-op.
logger.warning { "[id=$globalComputationId]: Computation stage has already been advanced" }
Expand All @@ -102,44 +103,57 @@ class AsyncComputationControlService(
.asRuntimeException()
}

// Record the key provided as the path to the output blob.
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response: RecordOutputBlobPathResponse =
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
if (stages.expectBlob(computationStage)) {
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response =
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
}

// Computation has changed, so use the new token.
// Computation has changed, so use the new token.
token = response.token
}
}

if (stages.expectStageInput(token)) {
val computationDetails =
stages.updateComputationDetails(token.computationDetails, request.computationStageInput)
val response =
computationsClient.updateComputationDetails(
updateComputationDetailsRequest {
this.token = token
details = computationDetails
}
)
token = response.token
}

// Advance the computation to next stage if all blob paths are present.
if (!token.outputPathList().any(String::isEmpty)) {
if (stages.readyForNextStage(token)) {
try {
computationsClient.advanceComputationStage(
computationToken = token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ kt_jvm_library(
"//src/main/proto/wfa/measurement/internal/duchy:computation_blob_dependency_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy:computation_protocols_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy:computation_token_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy/protocol:honest_majority_share_shuffle_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto",
],
)
Loading

0 comments on commit 1b0f32a

Please sign in to comment.