diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index d4ad82c32e5a4..35e334f120ca5 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2316,7 +2316,9 @@ AutoShardingSolverResult CallSolver( const std::vector<NodeStrategyIdx>& s_hint, int64_t memory_budget_per_device, bool crash_at_infinity_costs_check, bool compute_iis, int64_t solver_timeout_in_seconds, - bool allow_alias_to_follower_conversion) { + bool allow_alias_to_follower_conversion, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution) { // Serialize edges and edge costs to 1d numpy arrays AutoShardingSolverRequest request; request.num_nodes = leaf_strategies.size(); @@ -2346,10 +2348,16 @@ AutoShardingSolverResult CallSolver( // Serialize node costs for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { const StrategyVector* strategies = leaf_strategies[node_idx]; + auto instruction_name = instructions.at(strategies->instruction_id)->name(); request.instruction_names.push_back( - absl::StrCat(instructions.at(strategies->instruction_id)->name(), - " (id: ", node_idx, ")")); + absl::StrCat(instruction_name, " (id: ", node_idx, ")")); std::vector<double> ci, di, mi, pi; + auto default_strategy = HloSharding::Replicate(); + auto iter = sharding_propagation_solution.find(instruction_name); + if (iter != sharding_propagation_solution.end()) { + CHECK(iter->second->has_sharding()) << iter->second->ToString(); + default_strategy = iter->second->sharding(); + } for (NodeStrategyIdx j = 0; j < strategies->leaf_vector.size(); ++j) { const ShardingStrategy& strategy = strategies->leaf_vector[j]; const HloSharding& sharding = strategy.output_sharding; @@ -2359,7 +2367,7 @@ AutoShardingSolverResult CallSolver( mi.push_back(strategy.memory_cost); // TODO(moffitt): Revisit the default strategy below, which is currently // defined as the "trivial sharding" in hlo_sharding.h - pi.push_back(sharding.IsReplicated() && !sharding.IsManual() ? 0.0 : 1.0); + pi.push_back(sharding == default_strategy ? 0.0 : 1.0); } request.c.push_back(ci); request.d.push_back(di); @@ -3974,7 +3982,9 @@ AutoShardingImplementation::AutoShardingImplementation( StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding( HloModule* module, const absl::flat_hash_set<std::string>& replicated_small_tensors, - const absl::flat_hash_set<absl::string_view>& execution_threads) { + const absl::flat_hash_set<absl::string_view>& execution_threads, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution) { if (!option_.enable) { return AutoShardingResult::kModuleUnchanged; } @@ -4226,7 +4236,7 @@ StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding( if (!solver_option.load_solution_vector) { auto solver_result = Solve(*hlo_live_range, liveness_set, strategy_map, leaf_strategies, - cost_graph, alias_set, option_); + cost_graph, alias_set, option_, sharding_propagation_solution); if (solver_result.skip_auto_sharding) { return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; } else if (!solver_result.status.ok()) { @@ -4317,6 +4327,13 @@ bool IsModuleManuallySharded(const HloModule* module) { return false; } +std::unique_ptr<HloModule> CloneModule(const HloModule* module) { + auto module_clone = module->Clone(""); + module_clone->set_layout_canonicalization_callback( + module->layout_canonicalization_callback()); + return module_clone; +} + StatusOr<bool> AutoSharding::Run( HloModule* module, const absl::flat_hash_set<absl::string_view>& execution_threads) { @@ -4383,6 +4400,35 @@ StatusOr<bool> AutoSharding::Run( mesh_shapes.push_back(option_.device_mesh_shape); } + absl::flat_hash_map<std::string, const HloInstruction*> + sharding_propagation_solution; + std::unique_ptr<HloModule> module_with_default_solution = nullptr; + if (option_.use_sharding_propagation_for_default_shardings) { + module_with_default_solution = CloneModule(module); + // TODO(pratikf): Ensure that we're passing the correct customc all sharding + // helper to the sharding propagation pass. + auto sharding_prop = ShardingPropagation( + /*is_spmd */ true, /*propagate_metadata */ false, + /*allow_spmd_sharding_propagation_to_output*/ + module->config().allow_spmd_sharding_propagation_to_output(), + /*allow_spmd_sharding_propagation_to_parameters */ + absl::InlinedVector<bool, 1>{false}, + /*cse_prevention_only */ false, + /*sharding_helper*/ nullptr); + + CHECK_OK(sharding_prop.Run(module_with_default_solution.get(), + execution_threads)); + LOG(INFO) << module_with_default_solution->ToString(); + for (const auto computation : + module_with_default_solution->computations()) { + for (const auto instruction : computation->instructions()) { + if (instruction->has_sharding()) { + sharding_propagation_solution[instruction->name()] = instruction; + } + } + } + } + size_t num_meshes = mesh_shapes.size(); std::vector<std::unique_ptr<HloModule>> modules(num_meshes); std::vector<StatusOr<AutoShardingResult>> changed( @@ -4399,11 +4445,10 @@ StatusOr<bool> AutoSharding::Run( AutoShardingOption this_option = option_; this_option.device_mesh_shape = mesh_shapes[i]; auto pass = new AutoShardingImplementation(this_option); - auto module_clone = module->Clone(""); - module_clone->set_layout_canonicalization_callback( - module->layout_canonicalization_callback()); - auto pass_result = pass->RunAutoSharding( - module_clone.get(), replicated_small_tensors, execution_threads); + auto module_clone = CloneModule(module); + auto pass_result = + pass->RunAutoSharding(module_clone.get(), replicated_small_tensors, + execution_threads, sharding_propagation_solution); changed[i] = pass_result; objective_values[i] = pass->GetSolverOptimalObjectiveValue(); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index ca05cba5e23c3..ab51cac5af966 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -69,7 +69,9 @@ class AutoShardingImplementation { StatusOr<AutoShardingResult> RunAutoSharding( HloModule* module, const absl::flat_hash_set<std::string>& replicated_small_tensors, - const absl::flat_hash_set<absl::string_view>& execution_threads); + const absl::flat_hash_set<absl::string_view>& execution_threads, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution = {}); // Removes SPMD annotations (if there are) to test AutoSharding on manually // annotated graphs. @@ -210,13 +212,13 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ClusterEnvironment& cluster_env); // The high-level "recipe" for solving an Auto Sharding problem. -AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range, - const LivenessSet& liveness_set, - const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, - const CostGraph& cost_graph, - const AliasSet& alias_set, - const AutoShardingOption& option); +AutoShardingSolverResult Solve( + const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set, + const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies, + const CostGraph& cost_graph, const AliasSet& alias_set, + const AutoShardingOption& option, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution = {}); // Populates temporal distance values. void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 51052f7d8efd4..2a4afaa108bbd 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -23,19 +23,19 @@ limitations under the License. namespace xla { namespace spmd { -AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range, - const LivenessSet& liveness_set, - const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, - const CostGraph& cost_graph, - const AliasSet& alias_set, - const AutoShardingOption& option) { +AutoShardingSolverResult Solve( + const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set, + const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies, + const CostGraph& cost_graph, const AliasSet& alias_set, + const AutoShardingOption& option, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution) { return CallSolver( hlo_live_range, liveness_set, strategy_map, leaf_strategies, cost_graph, alias_set, /*s_hint*/ {}, option.memory_budget_per_device, /*crash_at_infinity_costs_check*/ !option.try_multiple_mesh_shapes, /*compute_iis*/ true, option.solver_timeout_in_seconds, - option.allow_alias_to_follower_conversion); + option.allow_alias_to_follower_conversion, sharding_propagation_solution); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index b14df2abe7697..465bfd1cdbb20 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -175,6 +175,11 @@ struct AutoShardingOption { // sharding. int64_t small_tensor_byte_size = 0; + // In order to obtain default sharding strategies for instructions to limit + // departures from the defaults, use sharding propagation instead of assuming + // a simple replicated default. + bool use_sharding_propagation_for_default_shardings = true; + std::string ToString() { std::vector<std::string> lines; lines.push_back(absl::StrCat("preserve_shardings: ", preserve_shardings)); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index b22713b4de9b4..82913283f35df 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -36,7 +36,9 @@ AutoShardingSolverResult CallSolver( const std::vector<NodeStrategyIdx>& s_hint, int64_t memory_budget_per_device, bool crash_at_infinity_costs_check, bool compute_iis, int64_t solver_timeout_in_seconds, - bool allow_alias_to_follower_conversion); + bool allow_alias_to_follower_conversion, + const absl::flat_hash_map<std::string, const HloInstruction*>& + sharding_propagation_solution = {}); } // namespace spmd } // namespace xla