Multidevice segmenter (#1696)
# What
Add an option to the segmenter that places each resharding Expr in its
own singleton segment.
To trigger it, set the segmenter's options as follows:
```
    SegmentCandidateFinderOptions options{
        .run_translate_welford = false,
        .run_combine_reductions = false,
        .run_herrmann_merge = true,
        .run_final_merge = true,
        .only_segment_resharding_exprs = true};
```
and call the segmenter with `nullptr` as inputs (the inputs are now passed
by pointer, and this mode does not depend on input data):
```
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), nullptr, options);
```
If `only_segment_resharding_exprs` is set to `false` (which is the case
by default), the behavior of the segmenter is unchanged.


We also extend the parameterized resharding tests so that the
implementation is validated on a fairly broad set of cases.

# Why 
Resharding Exprs need to be handled differently from other Exprs because
we want them to result in posting a network collective from the host.
Therefore those expressions cannot (for now) be fused into any kernel. For
this reason, each of those Exprs must be segmented off from the Exprs
before and after it.

# How
_**Remark:** For now, the segmenter is used in only [one place, before
scheduling and compiling the
fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are
Exprs and then eagerly tries to merge neighboring segments
incrementally. The method
```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2) 
```
returns whether two groups can be merged (i.e. fused into one kernel). 

With the current patch, if
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to
`true`, then the usual behavior of `codeGenSupportedMerge` is bypassed:
the function returns `false` if any Expr in either group is resharding,
and `true` otherwise.
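The check can be illustrated with a minimal, self-contained sketch. `MockExpr`, `MockGroup`, and `canMergeReshardingOnly` are hypothetical stand-ins introduced here for illustration only; the real code operates on `SegmentedGroup*` and calls `isResharding(expr)`, as shown in the diff of this commit.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for nvFuser's Expr and SegmentedGroup; only the
// information needed to illustrate the check is modeled.
struct MockExpr {
  bool is_resharding;
};

struct MockGroup {
  std::vector<MockExpr> exprs;
};

// Mirrors the early-return logic of the patch: merging is refused as soon as
// either group contains a resharding expression, and allowed otherwise.
bool canMergeReshardingOnly(const MockGroup& group1, const MockGroup& group2) {
  for (const MockGroup* group : {&group1, &group2}) {
    for (const MockExpr& expr : group->exprs) {
      if (expr.is_resharding) {
        return false;
      }
    }
  }
  return true;
}

// Small self-check that can be called from main().
void demoMergePredicate() {
  MockGroup compute;
  compute.exprs.push_back(MockExpr{false});
  MockGroup communication;
  communication.exprs.push_back(MockExpr{true});

  assert(canMergeReshardingOnly(compute, compute));
  assert(!canMergeReshardingOnly(compute, communication));
}
```

Since a group holding a resharding Expr can never merge with a neighbor, it stays at its initial size of one Expr, which is what makes the resulting segment a singleton.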

Because this segmentation shouldn't depend on the input data, the
`KernelArgumentHolder` is now passed by pointer and may be null. Without
valid inputs it is invalid to instantiate a `SchedulerRuntimeInfo
runtime_info_`. For this reason, we had to make the latter attribute
optional.
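The pattern can be sketched as follows, assuming C++17. `MockInputs`, `MockRuntimeInfo`, and `MockSegmenter` are hypothetical names used only for illustration, and the sketch throws `std::logic_error` where the actual patch uses `NVF_ERROR`.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for KernelArgumentHolder.
struct MockInputs {
  int num_args;
};

// Hypothetical stand-in for SchedulerRuntimeInfo: it can only be built from
// actual inputs.
struct MockRuntimeInfo {
  explicit MockRuntimeInfo(const MockInputs& inputs)
      : num_args(inputs.num_args) {}
  int num_args;
};

class MockSegmenter {
 public:
  // A null pointer means "no inputs": the runtime info is left empty.
  explicit MockSegmenter(const MockInputs* inputs)
      : runtime_info_(
            inputs == nullptr
                ? std::nullopt
                : std::make_optional<MockRuntimeInfo>(*inputs)) {}

  bool hasRuntimeInfo() const {
    return runtime_info_.has_value();
  }

  // Guarded accessor: fails loudly instead of dereferencing an empty optional.
  MockRuntimeInfo& runtimeInfo() {
    if (!runtime_info_.has_value()) {
      throw std::logic_error("needs runtime info");
    }
    return runtime_info_.value();
  }

 private:
  std::optional<MockRuntimeInfo> runtime_info_;
};

// Small self-check that can be called from main().
void demoOptionalRuntimeInfo() {
  MockSegmenter without_inputs(nullptr);
  assert(!without_inputs.hasRuntimeInfo());

  MockInputs inputs{2};
  MockSegmenter with_inputs(&inputs);
  assert(with_inputs.runtimeInfo().num_args == 2);
}
```

Routing every access through the guarded accessor means that a code path which still needs runtime information fails with a clear error when the segmenter was built without inputs.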

# Future/other directions

Another way to achieve the same result is to manually add segment bounds
surrounding the resharding Exprs, as was suggested by @wujingyue in
#1571.

The current implementation looks a bit "hacky" and should be
integrated more properly once multidevice schedulers are implemented
and/or the segmenter is refactored.

Later, we might want to be able to fuse communication with computation,
and also to fuse communications with one another. This would require a
more advanced segmenter and scheduler, but hopefully this patch can serve
as a good basis.

# Example
Consider the fusion:
```
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor({4});
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0,{3});
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = sum(tv2, {2});
  fusion->addOutput(tv3);
```

Manually scheduled as follows:
```
  DeviceMesh mesh({0, 1, 2, 3});
  for (auto tv : {tv0, tv1, tv2, tv3}) {
    tv->setDeviceMesh(mesh);
  }
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::DIDx);
```
This scheduling implies that
- `tv0` and `tv1` are fully sharded across the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)`
actually embeds an "AllGather" network collective. This Expr is
resharding while all the other Exprs are not. We thus expect this
expression to constitute an unmergeable segment.

In this situation, the segmenter with the option
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` set
to `true` produces three segments:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1:  with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`
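The outcome above can be reproduced with a small simulation. This is a simplified sketch under stated assumptions: the fusion is modeled as a linear chain rather than a graph, and `ChainExpr`, `Segment`, and `segmentChain` are hypothetical names that are not part of nvFuser.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of one expression in a linear producer-consumer chain.
struct ChainExpr {
  std::string text;
  bool is_resharding;
};

using Segment = std::vector<ChainExpr>;

bool containsResharding(const Segment& segment) {
  for (const ChainExpr& expr : segment) {
    if (expr.is_resharding) {
      return true;
    }
  }
  return false;
}

// Walks the chain in order and merges each expression into the previous
// segment unless either side contains a resharding expression.
std::vector<Segment> segmentChain(const std::vector<ChainExpr>& exprs) {
  std::vector<Segment> segments;
  for (const ChainExpr& expr : exprs) {
    const bool can_merge = !segments.empty() && !expr.is_resharding &&
        !containsResharding(segments.back());
    if (can_merge) {
      segments.back().push_back(expr);
    } else {
      segments.push_back(Segment{expr});
    }
  }
  return segments;
}

// The fusion from the example: two reductions separated by a resharding set.
void checkExampleFusion() {
  const std::vector<Segment> segments = segmentChain({
      {"tv1 = sum(tv0, {3})", false},
      {"tv2 = set(tv1)", true},
      {"tv3 = sum(tv2, {2})", false},
  });
  assert(segments.size() == 3);
  assert(segments[1].size() == 1 && segments[1][0].is_resharding);
}
```

In a longer chain, consecutive non-resharding expressions would collapse into a single compute segment, while each resharding expression would still stand alone.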
samnordmann authored Feb 1, 2024
1 parent 01ff3a2 commit 3e1e11e
Showing 6 changed files with 75 additions and 21 deletions.
2 changes: 1 addition & 1 deletion csrc/fusion.cpp
@@ -63,7 +63,7 @@ void swap(Fusion& a, Fusion& b) noexcept {
std::unique_ptr<SegmentedFusion> Fusion::segment(
const KernelArgumentHolder& args) {
FUSER_PERF_SCOPE("Segment Fusion");
return SegmentCandidateFinder::segment(this, args);
return SegmentCandidateFinder::segment(this, &args);
}

IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
40 changes: 33 additions & 7 deletions csrc/fusion_segmenter.cpp
@@ -14,6 +14,7 @@
#include <ir/graphviz.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <multidevice/utils.h>
#include <ops/arith.h>
#include <scheduler/debug_utils.h>
#include <scheduler/normalization_utils.h>
@@ -1866,7 +1867,7 @@ std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {

std::unique_ptr<SegmentedFusion> SegmentCandidateFinder::segment(
std::unique_ptr<Fusion> fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SchedulerRuntimeInfo& runtime_info) {
if (!hasSegmentHints(fusion.get())) {
scheduler_debug_utils::canScheduleMessage(
@@ -1875,7 +1876,7 @@ std::unique_ptr<SegmentedFusion> SegmentCandidateFinder::segment(
SchedulerEntry::proposeHeuristics(fusion.get(), runtime_info);
if (maybe_complete_fusion_heuristic.has_value()) {
return SegmentedFusion::fromCompleteFusion(
std::move(fusion), maybe_complete_fusion_heuristic.value(), inputs);
std::move(fusion), maybe_complete_fusion_heuristic.value(), *inputs);
}
}
if (fusion) {
@@ -3548,27 +3549,52 @@ bool SegmentCandidateFinder::codeGenSupportedMerge(
NVF_ERROR(
areDirectlyConnected(group1, group2),
"only support testing immediate producer-consumer groups");
auto h = tryMerge(segmented_fusion_.get(), runtime_info_, group1, group2);
if (options_.only_segment_resharding_exprs) {
for (auto group : {group1, group2}) {
for (auto expr : group->exprs()) {
if (isResharding(expr)) {
return false;
}
}
}
return true;
}
auto h = tryMerge(segmented_fusion_.get(), runtimeInfo(), group1, group2);
return h.has_value();
}

// TODO: consider caching the heuristics value so tryMerge doesn't have to be
// called twice
ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
SegmentedGroup* group) {
auto h = tryMerge(segmented_fusion_.get(), runtime_info_, group);
if (options_.only_segment_resharding_exprs) {
// We don't need to generate a heuristic for multidevice segments at this
// moment
return ScheduleHeuristic::None;
}
auto h = tryMerge(segmented_fusion_.get(), runtimeInfo(), group);
NVF_ERROR(
h.has_value(), "Can not find a scheduler to schedule fusion segment");
return h.value();
}

SegmentCandidateFinder::SegmentCandidateFinder(
std::unique_ptr<Fusion> fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SegmentCandidateFinderOptions options)
: options_(options),
runtime_info_(fusion.get(), inputs),
runtime_info_(
inputs == nullptr ? std::nullopt
: std::make_optional<SchedulerRuntimeInfo>(
fusion.get(),
*inputs)),
runtime_inputs_(inputs) {
NVF_ERROR(
!options_.only_segment_resharding_exprs ||
(!options_.run_translate_welford &&
!options_.run_combine_reductions && options_.run_herrmann_merge &&
options_.run_final_merge),
"Invalid Segmenter options");
segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
findSegments();
}
@@ -3688,7 +3714,7 @@ void SegmentCandidateFinder::findSegments() {

if (options_.run_translate_welford && has_welford_ops) {
if (TranslateApplicableWelford::run(
segmented_fusion_.get(), runtime_inputs_)) {
segmented_fusion_.get(), *runtime_inputs_)) {
// If modified, rebuild segments as existing expressions may be
// pulled into welford groups
buildInitialSegments();
20 changes: 12 additions & 8 deletions csrc/fusion_segmenter.h
@@ -528,6 +528,7 @@ struct SegmentCandidateFinderOptions {
bool run_combine_reductions = true;
bool run_herrmann_merge = true;
bool run_final_merge = true;
bool only_segment_resharding_exprs = false;
};

//! SegmentCandidateFinder
@@ -558,7 +559,7 @@ class SegmentCandidateFinder {
// Perform segmentation on a copy of the given fusion
static std::unique_ptr<SegmentedFusion> segment(
const Fusion* fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
auto fusion_copy = std::make_unique<Fusion>(*fusion);
return segment(std::move(fusion_copy), inputs, options);
@@ -567,7 +568,7 @@ class SegmentCandidateFinder {
// Perform segmentation on and take ownership of the given fusion
static std::unique_ptr<SegmentedFusion> segment(
std::unique_ptr<Fusion> fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
debug() << "Segment the fusion (Original Fusion Un-modified): "
@@ -580,7 +581,7 @@ class SegmentCandidateFinder {

static std::unique_ptr<SegmentedFusion> segment(
std::unique_ptr<Fusion> fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SchedulerRuntimeInfo& runtime_info);

static bool hasSegmentHints(Fusion* fusion);
@@ -593,7 +594,7 @@ class SegmentCandidateFinder {
// Perform segmentation on and take ownership of the given fusion
SegmentCandidateFinder(
std::unique_ptr<Fusion> fusion,
const KernelArgumentHolder& inputs,
const KernelArgumentHolder* inputs,
SegmentCandidateFinderOptions options);

void resetTraversal();
@@ -637,11 +638,12 @@ class SegmentCandidateFinder {
}

SchedulerRuntimeInfo& runtimeInfo() {
return runtime_info_;
NVF_ERROR(runtime_info_.has_value(), "needs runtime info");
return runtime_info_.value();
}

ExpressionEvaluator& expressionEvaluator() {
return runtime_info_.expressionEvaluator();
return runtimeInfo().expressionEvaluator();
}

//! Additional merging iteration, clean up the rest of
@@ -751,7 +753,9 @@ class SegmentCandidateFinder {
// unary ops on inputs to the complete fusion
VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs_;

SchedulerRuntimeInfo runtime_info_;
// This is allowed to be null in the multidevice case where the segmenter is
// used for breaking the fusion into compute and communication segments
std::optional<SchedulerRuntimeInfo> runtime_info_;

//! Note:
//! Segmenter should eventually rely only on runtime_info_ for
@@ -769,7 +773,7 @@ class SegmentCandidateFinder {
//! TODO:
//! implement the expression evaluator transfer and
//! remove runtime_inputs_ in a follow up.
const KernelArgumentHolder& runtime_inputs_;
const KernelArgumentHolder* runtime_inputs_;
};

// TODO: Make as member functions on classes instead of global scope
2 changes: 1 addition & 1 deletion csrc/kernel_cache.cpp
@@ -987,7 +987,7 @@ FusionKernelRuntime::FusionKernelRuntime(
// Default compilation path applies segmentation before scheduling and
// compiling the fusion.
segmented_fusion_ =
SegmentCandidateFinder::segment(std::move(fusion), args, runtime_info);
SegmentCandidateFinder::segment(std::move(fusion), &args, runtime_info);
} else {
// Serialization path that generates segmented fusion from flatbuffers.
// Convert Welford to two-pass if option is enabled and the original
8 changes: 4 additions & 4 deletions test/test_gpu2.cpp
@@ -5210,7 +5210,7 @@ TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) {
args.push(t0);

auto segmented_fusion =
SegmentCandidateFinder::segment(fusion.get(), args, segment_options);
SegmentCandidateFinder::segment(fusion.get(), &args, segment_options);

NVF_CHECK(segmented_fusion->groups().size() == 2);
}
@@ -5256,7 +5256,7 @@ TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) {
args.push(scalar);

auto segmented_fusion =
SegmentCandidateFinder::segment(fusion.get(), args, segment_options);
SegmentCandidateFinder::segment(fusion.get(), &args, segment_options);

NVF_CHECK(segmented_fusion->groups().size() == 2);
}
@@ -5299,7 +5299,7 @@ TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) {
args.push(t0);

auto segmented_fusion =
SegmentCandidateFinder::segment(fusion.get(), args, segment_options);
SegmentCandidateFinder::segment(fusion.get(), &args, segment_options);

NVF_CHECK(segmented_fusion->groups().size() <= 2);
}
@@ -7844,7 +7844,7 @@ TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) {
for (auto i : c10::irange(5)) {
(void)i; // Suppress unused variable warning
auto segmented_fusion =
SegmentCandidateFinder::segment(fusion_ptr.get(), args);
SegmentCandidateFinder::segment(fusion_ptr.get(), &args);
}
}

24 changes: 24 additions & 0 deletions test/test_pipeline.cpp
@@ -7,7 +7,9 @@
// clang-format on
#include <gtest/gtest.h>

#include <executor_kernel_arg.h>
#include <fusion.h>
#include <fusion_segmenter.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <multidevice/lower_communication.h>
@@ -396,6 +398,28 @@ class automaticReshardingTest
GTEST_EXPECT_TRUE(!isResharding(expr) || isLowerableToCommunication(expr))
<< "on expr=" << expr;
}

SegmentCandidateFinderOptions options{
.run_translate_welford = false,
.run_combine_reductions = false,
.run_herrmann_merge = true,
.run_final_merge = true,
.only_segment_resharding_exprs = true};

auto segmented_fusion =
SegmentCandidateFinder::segment(std::move(fusion), nullptr, options);

for (SegmentedGroup* group : segmented_fusion->groups()) {
GTEST_EXPECT_TRUE(
std::none_of(
group->exprs().begin(),
group->exprs().end(),
[](auto expr) { return isResharding(expr); }) ||
(group->exprs().size() == 1 && isResharding(group->exprs().at(0))));
}
// checks that the segments are disjoints and that the graph of segment is
// acyclic
segmented_fusion->validate();
}

std::unique_ptr<Fusion> fusion;