Skip to content

Commit

Permalink
[SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694567084
  • Loading branch information
Varcho authored and Google-ML-Automation committed Nov 8, 2024
1 parent 22c2e04 commit d9e4ec5
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions xla/service/sharding_remover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ absl::StatusOr<bool> ShardingRemover::Run(

const absl::flat_hash_set<absl::string_view> to_remove_sharding_ops = {
"Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape",
sdy::kShardingGroupCustomCallTargetName,
sdy::kFuncResultShardingTargetName};

for (HloComputation* computation : module->computations(execution_threads)) {
Expand All @@ -57,6 +58,14 @@ absl::StatusOr<bool> ShardingRemover::Run(
}
CHECK(instruction->operand_count() == 1)
<< "Sharding instruction must have exactly one operand";

// ShardingGroupOp is dangling so we just remove it.
if (instruction->custom_call_target() ==
sdy::kShardingGroupCustomCallTargetName) {
TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
continue;
}

TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(
instruction->mutable_operand(0), name()));
changed = true;
Expand Down

0 comments on commit d9e4ec5

Please sign in to comment.