diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index d4ad82c32e5a4..35e334f120ca5 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -2316,7 +2316,9 @@ AutoShardingSolverResult CallSolver(
     const std::vector<NodeStrategyIdx>& s_hint,
     int64_t memory_budget_per_device, bool crash_at_infinity_costs_check,
     bool compute_iis, int64_t solver_timeout_in_seconds,
-    bool allow_alias_to_follower_conversion) {
+    bool allow_alias_to_follower_conversion,
+    const absl::flat_hash_map<std::string, const HloInstruction*>&
+        sharding_propagation_solution) {
   // Serialize edges and edge costs to 1d numpy arrays
   AutoShardingSolverRequest request;
   request.num_nodes = leaf_strategies.size();
@@ -2346,10 +2348,16 @@ AutoShardingSolverResult CallSolver(
   // Serialize node costs
   for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) {
     const StrategyVector* strategies = leaf_strategies[node_idx];
+    auto instruction_name = instructions.at(strategies->instruction_id)->name();
     request.instruction_names.push_back(
-        absl::StrCat(instructions.at(strategies->instruction_id)->name(),
-                     " (id: ", node_idx, ")"));
+        absl::StrCat(instruction_name, " (id: ", node_idx, ")"));
     std::vector<double> ci, di, mi, pi;
+    auto default_strategy = HloSharding::Replicate();
+    auto iter = sharding_propagation_solution.find(instruction_name);
+    if (iter != sharding_propagation_solution.end()) {
+      CHECK(iter->second->has_sharding()) << iter->second->ToString();
+      default_strategy = iter->second->sharding();
+    }
     for (NodeStrategyIdx j = 0; j < strategies->leaf_vector.size(); ++j) {
       const ShardingStrategy& strategy = strategies->leaf_vector[j];
       const HloSharding& sharding = strategy.output_sharding;
@@ -2359,7 +2367,7 @@ AutoShardingSolverResult CallSolver(
       mi.push_back(strategy.memory_cost);
       // TODO(moffitt): Revisit the default strategy below, which is currently
       // defined as the "trivial sharding" in hlo_sharding.h
-      pi.push_back(sharding.IsReplicated() && !sharding.IsManual() ? 0.0 : 1.0);
+      pi.push_back(sharding == default_strategy ? 0.0 : 1.0);
     }
     request.c.push_back(ci);
     request.d.push_back(di);
@@ -3974,7 +3982,9 @@ AutoShardingImplementation::AutoShardingImplementation(
 StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
     HloModule* module,
     const absl::flat_hash_set<std::string>& replicated_small_tensors,
-    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+    const absl::flat_hash_set<absl::string_view>& execution_threads,
+    const absl::flat_hash_map<std::string, const HloInstruction*>&
+        sharding_propagation_solution) {
   if (!option_.enable) {
     return AutoShardingResult::kModuleUnchanged;
   }
@@ -4226,7 +4236,7 @@ StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
     if (!solver_option.load_solution_vector) {
       auto solver_result =
           Solve(*hlo_live_range, liveness_set, strategy_map, leaf_strategies,
-                cost_graph, alias_set, option_);
+                cost_graph, alias_set, option_, sharding_propagation_solution);
       if (solver_result.skip_auto_sharding) {
         return AutoShardingResult::kModuleUnchangedNoShardingPerfomed;
       } else if (!solver_result.status.ok()) {
@@ -4317,6 +4327,13 @@ bool IsModuleManuallySharded(const HloModule* module) {
   return false;
 }
 
+std::unique_ptr<HloModule> CloneModule(const HloModule* module) {
+  auto module_clone = module->Clone("");
+  module_clone->set_layout_canonicalization_callback(
+      module->layout_canonicalization_callback());
+  return module_clone;
+}
+
 StatusOr<bool> AutoSharding::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
@@ -4383,6 +4400,35 @@ StatusOr<bool> AutoSharding::Run(
     mesh_shapes.push_back(option_.device_mesh_shape);
   }
 
+  absl::flat_hash_map<std::string, const HloInstruction*>
+      sharding_propagation_solution;
+  std::unique_ptr<HloModule> module_with_default_solution = nullptr;
+  if (option_.use_sharding_propagation_for_default_shardings) {
+    module_with_default_solution = CloneModule(module);
+    // TODO(pratikf): Ensure that we're passing the correct custom call
+    // sharding helper to the sharding propagation pass.
+    auto sharding_prop = ShardingPropagation(
+        /*is_spmd */ true, /*propagate_metadata */ false,
+        /*allow_spmd_sharding_propagation_to_output*/
+        module->config().allow_spmd_sharding_propagation_to_output(),
+        /*allow_spmd_sharding_propagation_to_parameters */
+        absl::InlinedVector<bool, 1>{false},
+        /*cse_prevention_only */ false,
+        /*sharding_helper*/ nullptr);
+
+    CHECK_OK(sharding_prop.Run(module_with_default_solution.get(),
+                               execution_threads));
+    LOG(INFO) << module_with_default_solution->ToString();
+    for (const auto computation :
+         module_with_default_solution->computations()) {
+      for (const auto instruction : computation->instructions()) {
+        if (instruction->has_sharding()) {
+          sharding_propagation_solution[instruction->name()] = instruction;
+        }
+      }
+    }
+  }
+
   size_t num_meshes = mesh_shapes.size();
   std::vector<std::unique_ptr<HloModule>> modules(num_meshes);
   std::vector<StatusOr<AutoShardingResult>> changed(
@@ -4399,11 +4445,10 @@ StatusOr<bool> AutoSharding::Run(
     AutoShardingOption this_option = option_;
     this_option.device_mesh_shape = mesh_shapes[i];
     auto pass = new AutoShardingImplementation(this_option);
-    auto module_clone = module->Clone("");
-    module_clone->set_layout_canonicalization_callback(
-        module->layout_canonicalization_callback());
-    auto pass_result = pass->RunAutoSharding(
-        module_clone.get(), replicated_small_tensors, execution_threads);
+    auto module_clone = CloneModule(module);
+    auto pass_result =
+        pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
+                              execution_threads, sharding_propagation_solution);
 
     changed[i] = pass_result;
     objective_values[i] = pass->GetSolverOptimalObjectiveValue();
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h
index ca05cba5e23c3..ab51cac5af966 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -69,7 +69,9 @@ class AutoShardingImplementation {
   StatusOr<AutoShardingResult> RunAutoSharding(
       HloModule* module,
       const absl::flat_hash_set<std::string>& replicated_small_tensors,
-      const absl::flat_hash_set<absl::string_view>& execution_threads);
+      const absl::flat_hash_set<absl::string_view>& execution_threads,
+      const absl::flat_hash_map<std::string, const HloInstruction*>&
+          sharding_propagation_solution = {});
 
   // Removes SPMD annotations (if there are) to test AutoSharding on manually
   // annotated graphs.
@@ -210,13 +212,13 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
                                    const ClusterEnvironment& cluster_env);
 
 // The high-level "recipe" for solving an Auto Sharding problem.
-AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range,
-                               const LivenessSet& liveness_set,
-                               const StrategyMap& strategy_map,
-                               const LeafStrategies& leaf_strategies,
-                               const CostGraph& cost_graph,
-                               const AliasSet& alias_set,
-                               const AutoShardingOption& option);
+AutoShardingSolverResult Solve(
+    const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set,
+    const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies,
+    const CostGraph& cost_graph, const AliasSet& alias_set,
+    const AutoShardingOption& option,
+    const absl::flat_hash_map<std::string, const HloInstruction*>&
+        sharding_propagation_solution = {});
 
 // Populates temporal distance values.
 void PopulateTemporalValues(const CostGraph& cost_graph,
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
index 51052f7d8efd4..2a4afaa108bbd 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
@@ -23,19 +23,19 @@ limitations under the License.
 namespace xla {
 namespace spmd {
 
-AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range,
-                               const LivenessSet& liveness_set,
-                               const StrategyMap& strategy_map,
-                               const LeafStrategies& leaf_strategies,
-                               const CostGraph& cost_graph,
-                               const AliasSet& alias_set,
-                               const AutoShardingOption& option) {
+AutoShardingSolverResult Solve(
+    const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set,
+    const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies,
+    const CostGraph& cost_graph, const AliasSet& alias_set,
+    const AutoShardingOption& option,
+    const absl::flat_hash_map<std::string, const HloInstruction*>&
+        sharding_propagation_solution) {
   return CallSolver(
       hlo_live_range, liveness_set, strategy_map, leaf_strategies, cost_graph,
       alias_set, /*s_hint*/ {}, option.memory_budget_per_device,
       /*crash_at_infinity_costs_check*/ !option.try_multiple_mesh_shapes,
       /*compute_iis*/ true, option.solver_timeout_in_seconds,
-      option.allow_alias_to_follower_conversion);
+      option.allow_alias_to_follower_conversion, sharding_propagation_solution);
 }
 
 void PopulateTemporalValues(const CostGraph& cost_graph,
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
index b14df2abe7697..465bfd1cdbb20 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -175,6 +175,11 @@ struct AutoShardingOption {
   // sharding.
   int64_t small_tensor_byte_size = 0;
 
+  // In order to obtain default sharding strategies for instructions to limit
+  // departures from the defaults, use sharding propagation instead of assuming
+  // a simple replicated default.
+  bool use_sharding_propagation_for_default_shardings = true;
+
   std::string ToString() {
     std::vector<std::string> lines;
     lines.push_back(absl::StrCat("preserve_shardings: ", preserve_shardings));
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
index b22713b4de9b4..82913283f35df 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
@@ -36,7 +36,9 @@ AutoShardingSolverResult CallSolver(
     const std::vector<NodeStrategyIdx>& s_hint,
     int64_t memory_budget_per_device, bool crash_at_infinity_costs_check,
     bool compute_iis, int64_t solver_timeout_in_seconds,
-    bool allow_alias_to_follower_conversion);
+    bool allow_alias_to_follower_conversion,
+    const absl::flat_hash_map<std::string, const HloInstruction*>&
+        sharding_propagation_solution = {});
 
 }  // namespace spmd
 }  // namespace xla