Skip to content

Commit

Permalink
PR #18773: [ReduceScatterCombiner] Provide option to not combine with…
Browse files Browse the repository at this point in the history
…in while loop bodies.

Imported from GitHub PR #18773

Same as #18772, but for reduce-scatters. The description below is copied from #18772.

This PR adds an option, `combine_while_loops`, that controls whether reduce-scatters inside while loop bodies are combined.
The option defaults to true, so existing behavior is maintained.

This option is provided because some strategies for FSDP may want to coalesce only the collectives that are outside of a while loop. When the option is set to false, collectives inside a while loop are not coalesced, as we assume there is sufficient compute in the loop to overlap with them.
Copybara import of the project:

--
9a7d247 by ptoulme-aws <[email protected]>:

[ReduceScatterCombiner] Provide option to not combine within while loop bodies.

Merging this change closes #18773

COPYBARA_INTEGRATE_REVIEW=#18773 from ptoulme-aws:reduce_scatter_combine_while 9a7d247
PiperOrigin-RevId: 694566690
  • Loading branch information
ptoulme-aws authored and Google-ML-Automation committed Nov 8, 2024
1 parent 3e87afa commit 22c2e04
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
12 changes: 10 additions & 2 deletions xla/service/reduce_scatter_combiner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ absl::StatusOr<bool> ReduceScatterCombiner::RunWithKeyCombiner(
bool changed = false;
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
if (!combine_while_loops_ && computation->IsWhileBodyComputation()) {
VLOG(2) << "Skipping this computation because the computation is a while "
"loop body: "
<< computation->ToString();
continue;
}
TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

auto key_fn = [&](const HloInstruction* instruction) {
Expand All @@ -240,10 +246,12 @@ absl::StatusOr<bool> ReduceScatterCombiner::RunWithKeyCombiner(

ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes,
int64_t combine_threshold_count,
bool combine_by_dim)
bool combine_by_dim,
bool combine_while_loops)
: combine_threshold_in_bytes_(combine_threshold_in_bytes),
combine_threshold_count_(combine_threshold_count),
combine_by_dim_(combine_by_dim) {}
combine_by_dim_(combine_by_dim),
combine_while_loops_(combine_while_loops) {}

absl::StatusOr<bool> ReduceScatterCombiner::Run(
HloModule* module,
Expand Down
6 changes: 5 additions & 1 deletion xla/service/reduce_scatter_combiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ namespace xla {
class ReduceScatterCombiner : public HloModulePass {
public:
ReduceScatterCombiner(int64_t combine_threshold_in_bytes,
int64_t combine_threshold_count, bool combine_by_dim);
int64_t combine_threshold_count, bool combine_by_dim,
bool combine_while_loops = true);

absl::string_view name() const override { return "reduce-scatter-combiner"; }

Expand Down Expand Up @@ -77,6 +78,9 @@ class ReduceScatterCombiner : public HloModulePass {

// Combine only reduce-scatter ops with the same dimension.
bool combine_by_dim_;

// Combine reduce-scatter ops that are inside while loop body computations.
bool combine_while_loops_;
};

} // namespace xla
Expand Down
55 changes: 51 additions & 4 deletions xla/service/reduce_scatter_combiner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,16 @@ class ReduceScatterCombinerTest : public HloTestBase {
absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
absl::string_view hlo_module, bool expect_change,
int64_t byte_threshold = kMaxByteCount,
int64_t count_threshold = kMaxCombineCount, bool combine_by_dim = true) {
int64_t count_threshold = kMaxCombineCount, bool combine_by_dim = true,
bool combine_while_loops = true) {
TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module));

VLOG(1) << "Before running ReduceScatterCombiner: "
<< ReduceScatterCount(module.get()) << " reduce-scatter ops";

auto changed =
ReduceScatterCombiner(byte_threshold, count_threshold, combine_by_dim)
.Run(module.get());
auto changed = ReduceScatterCombiner(byte_threshold, count_threshold,
combine_by_dim, combine_while_loops)
.Run(module.get());
if (!changed.ok()) {
return changed.status();
}
Expand Down Expand Up @@ -302,5 +303,51 @@ ENTRY main {
EXPECT_EQ(ReduceScatterCount(module.get()), 1);
}

// Checks that the pass reports no change when it is constructed with
// combine_while_loops=false and the module's only reduce-scatter ops sit in a
// while loop body. In the HLO below both reduce-scatters are in `while_body`;
// ENTRY only builds the loop's init tuple and unpacks the loop result.
TEST_F(ReduceScatterCombinerTest, DoNotCombineInWhileLoop) {
absl::string_view hlo_string = R"(
HloModule m
sum_reduce {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_body {
param = (bf16[1024,32768]{1,0}, bf16[4096,8192]{1,0}) parameter(0)
param.0 = bf16[1024,32768]{1,0} get-tuple-element(param), index=0
param.1 = bf16[4096,8192]{1,0} get-tuple-element(param), index=1
reduce-scatter.0 = bf16[1024,32768]{1,0} reduce-scatter(param.0),
channel_id=132, replica_groups={{0}}, dimensions={0}, to_apply=sum_reduce
reduce-scatter.1 = bf16[4096,8192]{1,0} reduce-scatter(param.1),
channel_id=134, replica_groups={{0}}, dimensions={0}, to_apply=sum_reduce
ROOT tuple = tuple(reduce-scatter.0, reduce-scatter.1)
}
while_cond {
param = (bf16[1024,32768], bf16[4096,8192]) parameter(0)
ROOT cond = pred[] constant(true)
}
ENTRY main {
param.0 = bf16[1024,32768]{1,0} parameter(0)
param.1 = bf16[4096,8192]{1,0} parameter(1)
while_init = (bf16[1024,32768], bf16[4096,8192]) tuple(param.0, param.1)
while_loop = (bf16[1024,32768], bf16[4096,8192]) while(while_init), condition=while_cond, body=while_body
gte.0 = bf16[1024,32768] get-tuple-element(while_loop), index=0
gte.1 = bf16[4096,8192] get-tuple-element(while_loop), index=1
ROOT tuple = tuple(gte.0, gte.1)
})";

// Parse and verify the module directly; the pass is constructed by hand below
// so that combine_while_loops can be set to false.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

// NOTE(review): each reduce-scatter result is 2^25 bf16 elements (64 MiB),
// well above the 1 MiB byte threshold passed here. Confirm that the combiner
// would merge these two ops when combine_while_loops=true; if ops larger than
// the threshold are never combined, this expectation holds regardless of the
// flag and the test does not exercise the new option.
ReduceScatterCombiner combine(1024 * 1024, kMaxCombineCount,
/*combine_by_dim=*/false,
/*combine_while_loops=*/false);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
// With while-body combining disabled, the pass must leave the module as is.
EXPECT_FALSE(changed);
}

} // namespace
} // namespace xla

0 comments on commit 22c2e04

Please sign in to comment.