Skip to content

Commit

Permalink
Eliminate some bad strategy combinations for gather operands/outputs …
Browse files Browse the repository at this point in the history
…from the search space.

PiperOrigin-RevId: 683412789
  • Loading branch information
Google-ML-Automation committed Oct 8, 2024
1 parent 5b74999 commit ea35692
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
20 changes: 7 additions & 13 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,10 @@ BuildStrategyAndCost(
if (std::optional<HloSharding> improved_spec =
ConstructImprovedSharding(
to_merge, output_spec, gather_shape,
/* may_combine_partial_sharding */ true,
/* allow_aggressive_resharding */ false)) {
/*may_combine_partial_sharding=*/true,
/*allow_aggressive_resharding=*/false)) {
output_spec = *improved_spec;
add_sharding_strategy(data_spec, indices_spec, output_spec);
} else {
add_sharding_strategy(data_spec, indices_spec, to_merge);
}
}
// Infer output sharding from scatter indices sharding.
Expand All @@ -468,12 +466,10 @@ BuildStrategyAndCost(
if (std::optional<HloSharding> improved_spec =
ConstructImprovedSharding(
to_merge, output_spec, gather_shape,
/* may_combine_partial_sharding */ true,
/* allow_aggressive_resharding */ false)) {
/*may_combine_partial_sharding=*/true,
/*allow_aggressive_resharding=*/false)) {
output_spec = *improved_spec;
add_sharding_strategy(data_spec, indices_spec, output_spec);
} else {
add_sharding_strategy(data_spec, indices_spec, to_merge);
}
}
}
Expand All @@ -497,17 +493,15 @@ BuildStrategyAndCost(
if (std::optional<HloSharding> improved_spec =
ConstructImprovedSharding(
*maybe_from_data, output_spec, gather_shape,
/* may_combine_partial_sharding */ true,
/* allow_aggressive_resharding */ false)) {
/*may_combine_partial_sharding=*/true,
/*allow_aggressive_resharding=*/false)) {
output_spec = *improved_spec;
add_sharding_strategy(data_spec, indices_spec, output_spec);
} else {
add_sharding_strategy(data_spec, indices_spec, *maybe_from_data);
}
}
}
AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0,
/* operands_to_consider_all_strategies_for */ {0},
/*operands_to_consider_all_strategies_for=*/{0},
*strategy_group);
break;
}
Expand Down
3 changes: 2 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1997,8 +1997,9 @@ ENTRY %entry {
option.device_mesh_ids = {0, 1, 2, 3, 4, 5, 6, 7};
option.device_mesh_alpha = {1.0, 1.0, 1.0};
option.device_mesh_beta = {0.01, 1.0, 1.0};
option.memory_budget_per_device = (1000 * 128 + 8 * 128) / 8 + 8;
TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(10) << module->ToString();
VLOG(5) << module->ToString();
EXPECT_TRUE(changed);
const HloInstruction* gather = FindInstruction(module.get(), "gather");
const HloInstruction* data = FindInstruction(module.get(), "data");
Expand Down

0 comments on commit ea35692

Please sign in to comment.