
Commit

Merge branch 'main' into tristanvuong-replace-separate-inserts-with-1-large-insert-per-table
tristanvuong2021 authored Feb 27, 2024
2 parents 8852f35 + 6fa65e1 commit 1b92892
Showing 22 changed files with 952 additions and 1,196 deletions.
6 changes: 3 additions & 3 deletions src/main/k8s/dev/BUILD.bazel
@@ -1,3 +1,4 @@
load("@wfa_common_jvm//build:defs.bzl", "expand_template")
load("@wfa_rules_cue//cue:defs.bzl", "cue_library")
load(
"//build:variables.bzl",
@@ -8,9 +9,8 @@ load(
"KINGDOM_K8S_SETTINGS",
"SIMULATOR_K8S_SETTINGS",
)
load("@wfa_common_jvm//build:defs.bzl", "expand_template")
load("//src/main/k8s:macros.bzl", "cue_dump")
load("//build/k8s:defs.bzl", "kustomization_dir")
load("//src/main/k8s:macros.bzl", "cue_dump")

SECRET_NAME = "certs-and-configs"

@@ -339,13 +339,13 @@ EDP_SIMULATOR_TAGS = {
"image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag,
"kingdom_public_api_target": KINGDOM_K8S_SETTINGS.public_api_target,
"duchy_public_api_target": DUCHY_K8S_SETTINGS.public_api_target,
"google_cloud_project": GCLOUD_SETTINGS.project,
}

cue_dump(
name = "bigquery_edp_simulator_gke",
srcs = ["bigquery_edp_simulator_gke.cue"],
cue_tags = dict(EDP_SIMULATOR_TAGS.items() + {
"google_cloud_project": GCLOUD_SETTINGS.project,
"bigquery_dataset": SIMULATOR_K8S_SETTINGS.bigquery_dataset,
"bigquery_table": SIMULATOR_K8S_SETTINGS.bigquery_table,
}.items()),
11 changes: 0 additions & 11 deletions src/main/k8s/dev/bigquery_edp_simulator_gke.cue
@@ -14,8 +14,6 @@

package k8s

#SimulatorServiceAccount: "simulator"

_bigQueryConfig: #BigQueryConfig & {
dataset: string @tag("bigquery_dataset")
table: string @tag("bigquery_table")
@@ -41,16 +39,7 @@ edp_simulators: {
_container: {
resources: _resourceRequirements
}
spec: template: spec: #ServiceAccountPodSpec & {
serviceAccountName: #SimulatorServiceAccount
}
}
}
}
}

serviceAccounts: {
"\(#SimulatorServiceAccount)": #WorkloadIdentityServiceAccount & {
_iamServiceAccountName: "simulator"
}
}
11 changes: 10 additions & 1 deletion src/main/k8s/dev/edp_simulator_gke.cue
@@ -33,6 +33,8 @@ _secret_name: string @tag("secret_name")
_kingdomPublicApiTarget: string @tag("kingdom_public_api_target")
_duchyPublicApiTarget: string @tag("duchy_public_api_target")

#SimulatorServiceAccount: "simulator"

objectSets: [
serviceAccounts,
configMaps,
@@ -62,7 +64,9 @@ edp_simulators: {
_mc_resource_name: _mc_name

deployment: {
spec: template: spec: #SpotVmPodSpec
spec: template: spec: #SpotVmPodSpec & #ServiceAccountPodSpec & {
serviceAccountName: #SimulatorServiceAccount
}
}
}
}
@@ -71,6 +75,11 @@ edp_simulators: {
serviceAccounts: [Name=string]: #ServiceAccount & {
metadata: name: Name
}
serviceAccounts: {
"\(#SimulatorServiceAccount)": #WorkloadIdentityServiceAccount & {
_iamServiceAccountName: "simulator"
}
}

configMaps: [Name=string]: #ConfigMap & {
metadata: name: Name
@@ -30,7 +30,9 @@ import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.UNRECOGNIZED
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_TO_START
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.stageDetails
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.waitOnAggregationInputDetails

@@ -57,11 +59,11 @@ object HonestMajorityShareShuffleProtocol {

override val validSuccessors =
mapOf(
INITIALIZED to setOf(SETUP_PHASE),
// A Non-aggregator will skip WAIT_ON_SHUFFLE_INPUT into SHUFFLE_PHASE if the requisition
// data from EDPs and seed from the peer worker have been received.
SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT, SHUFFLE_PHASE),
WAIT_ON_SHUFFLE_INPUT to setOf(SHUFFLE_PHASE),
INITIALIZED to setOf(WAIT_TO_START, WAIT_ON_SHUFFLE_INPUT_PHASE_ONE),
WAIT_TO_START to setOf(SETUP_PHASE),
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE to setOf(SETUP_PHASE),
SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, SHUFFLE_PHASE),
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO to setOf(SHUFFLE_PHASE),
WAIT_ON_AGGREGATION_INPUT to setOf(AGGREGATION_PHASE),
SHUFFLE_PHASE to setOf(COMPLETE),
AGGREGATION_PHASE to setOf(COMPLETE),
@@ -91,10 +93,16 @@
): Boolean {
return when (stage) {
INITIALIZED,
WAIT_TO_START,
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
SETUP_PHASE,
SHUFFLE_PHASE -> details.role == RoleInComputation.NON_AGGREGATOR
WAIT_ON_AGGREGATION_INPUT,
AGGREGATION_PHASE -> details.role == RoleInComputation.AGGREGATOR
else -> true /* Stage can be executed at either primary or non-primary */
COMPLETE -> true /* Stage can be executed at either AGGREGATOR or NON_AGGREGATOR */
STAGE_UNSPECIFIED,
UNRECOGNIZED -> error("Invalid Stage. $stage")
}
}

@@ -106,7 +114,9 @@ object HonestMajorityShareShuffleProtocol {
SETUP_PHASE,
SHUFFLE_PHASE,
AGGREGATION_PHASE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE
WAIT_ON_SHUFFLE_INPUT,
WAIT_TO_START,
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
WAIT_ON_AGGREGATION_INPUT -> AfterTransition.DO_NOT_ADD_TO_QUEUE
COMPLETE -> error("Computation should be ended with call to endComputation(...)")
// Stages that we can't transition to ever.
@@ -122,19 +132,18 @@
): Int {
return when (stage) {
SETUP_PHASE,
WAIT_ON_SHUFFLE_INPUT,
SHUFFLE_PHASE -> 0
WAIT_TO_START -> 0
// The output of these stages are the data received from the peer non-aggregator duchy:
WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
// The output of these stages are the computed intermediate data:
SHUFFLE_PHASE,
AGGREGATION_PHASE -> 1
WAIT_ON_AGGREGATION_INPUT -> 2
AGGREGATION_PHASE ->
// The output is the intermediate computation result either received from another duchy
// or computed locally.
1
// Mill have nothing to do for this stage.
COMPLETE -> error("Computation should be ended with call to endComputation(...)")
// Stages that we can't transition to ever.
INITIALIZED,
UNRECOGNIZED,
STAGE_UNSPECIFIED,
INITIALIZED -> error("Cannot make transition function to stage $stage")
STAGE_UNSPECIFIED -> error("Cannot make transition function to stage $stage")
}
}

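The hunks above show only the changed lines of the stage machine. As a reading aid, here is a minimal, self-contained Kotlin sketch of the resulting transition graph; the enum, map, and helper are illustrative stand-ins rather than the project's real protobuf-generated types, but the stage names and successor sets mirror the ones introduced by this commit.

    // Illustrative sketch only: simplified stand-ins for the HonestMajorityShareShuffle
    // stage machine after this commit. Stage names and successor sets mirror the diff above;
    // the enum and helper are hypothetical, not the project's actual types.
    enum class Stage {
      INITIALIZED,
      WAIT_TO_START,
      WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
      SETUP_PHASE,
      WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
      WAIT_ON_AGGREGATION_INPUT,
      SHUFFLE_PHASE,
      AGGREGATION_PHASE,
      COMPLETE,
    }

    // Successor map matching the transitions added in the hunk above.
    val validSuccessors: Map<Stage, Set<Stage>> =
      mapOf(
        Stage.INITIALIZED to setOf(Stage.WAIT_TO_START, Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE),
        Stage.WAIT_TO_START to setOf(Stage.SETUP_PHASE),
        Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE to setOf(Stage.SETUP_PHASE),
        Stage.SETUP_PHASE to setOf(Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, Stage.SHUFFLE_PHASE),
        Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO to setOf(Stage.SHUFFLE_PHASE),
        Stage.WAIT_ON_AGGREGATION_INPUT to setOf(Stage.AGGREGATION_PHASE),
        Stage.SHUFFLE_PHASE to setOf(Stage.COMPLETE),
        Stage.AGGREGATION_PHASE to setOf(Stage.COMPLETE),
      )

    // A transition is legal only if the target is a declared successor of the current stage.
    fun isValidTransition(from: Stage, to: Stage): Boolean =
      to in validSuccessors[from].orEmpty()

    fun main() {
      check(isValidTransition(Stage.INITIALIZED, Stage.WAIT_TO_START))
      check(!isValidTransition(Stage.SETUP_PHASE, Stage.AGGREGATION_PHASE))
      println("transition checks passed")
    }
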
@@ -35,7 +35,6 @@ import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoro
import org.wfanet.measurement.internal.duchy.GetOutputBlobMetadataRequest
import org.wfanet.measurement.internal.duchy.getComputationTokenRequest
import org.wfanet.measurement.internal.duchy.recordOutputBlobPathRequest
import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest

/** Implementation of the internal Async Computation Control Service. */
class AsyncComputationControlService(
@@ -90,7 +89,7 @@ class AsyncComputationControlService(
.asRuntimeException()

val computationStage = token.computationStage
if (!stages.isValidStage(computationStage, request.computationStage)) {
if (computationStage != request.computationStage) {
if (computationStage == stages.nextStage(request.computationStage)) {
// This is technically an error, but it should be safe to treat as a no-op.
logger.warning { "[id=$globalComputationId]: Computation stage has already been advanced" }
@@ -103,57 +102,43 @@
.asRuntimeException()
}

if (stages.expectBlob(computationStage)) {
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response =
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
}

// Computation has changed, so use the new token.
token = response.token
val outputBlob =
token.blobsList.firstOrNull {
it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT
} ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" }
if (outputBlob.path.isNotEmpty()) {
if (outputBlob.path != request.blobPath) {
throw Status.FAILED_PRECONDITION.withDescription(
"Output blob ${outputBlob.blobId} already has a different path recorded"
)
.asRuntimeException()
}
}

if (stages.expectStageInput(token)) {
val computationDetails =
stages.updateComputationDetails(token.computationDetails, request.computationStageInput)
logger.info {
"[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}"
}
} else {
val response =
computationsClient.updateComputationDetails(
updateComputationDetailsRequest {
this.token = token
details = computationDetails
try {
computationsClient.recordOutputBlobPath(
recordOutputBlobPathRequest {
this.token = token
outputBlobId = outputBlob.blobId
blobPath = request.blobPath
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.ABORTED -> RetryableException(e)
else -> Status.UNKNOWN.withCause(e).asRuntimeException()
}
)
}

// Computation has changed, so use the new token.
token = response.token
}

if (stages.readyForNextStage(token)) {
// Advance the computation to next stage if all blob paths are present.
if (!token.outputPathList().any(String::isEmpty)) {
try {
computationsClient.advanceComputationStage(
computationToken = token,
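The rewritten handling above reduces to three ideas: the request's stage must exactly match the token's stage (an already-advanced stage is treated as a no-op), the output blob path write is idempotent, and an ABORTED status from the record call is surfaced as retryable. Below is a compact Kotlin sketch of just the idempotent-write rule, using hypothetical stand-in types rather than the service's real protos and gRPC stubs.

    // Illustrative sketch only: hypothetical types approximating the blob-path handling
    // shown in the diff above (OutputBlob and BlobPathConflict are stand-ins).
    data class OutputBlob(val blobId: Long, var path: String = "")

    class BlobPathConflict(message: String) : IllegalStateException(message)

    // Records requestedPath for the blob with blobId: a repeat of the same path is a
    // no-op, while an attempt to overwrite an existing, different path fails.
    fun recordOutputBlobPath(blobs: MutableList<OutputBlob>, blobId: Long, requestedPath: String) {
      val blob = blobs.firstOrNull { it.blobId == blobId }
        ?: throw BlobPathConflict("No output blob with ID $blobId")
      when {
        blob.path.isEmpty() -> blob.path = requestedPath // First write: record the path.
        blob.path == requestedPath -> Unit               // Duplicate request: no-op.
        else -> throw BlobPathConflict(                  // Conflicting path: reject.
          "Output blob $blobId already has a different path recorded"
        )
      }
    }

    fun main() {
      val blobs = mutableListOf(OutputBlob(blobId = 1))
      recordOutputBlobPath(blobs, 1, "gs://bucket/blob-1") // records the path
      recordOutputBlobPath(blobs, 1, "gs://bucket/blob-1") // idempotent no-op
      runCatching { recordOutputBlobPath(blobs, 1, "gs://bucket/other") }
        .onFailure { println("rejected: ${it.message}") }
    }
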