Skip to content

Commit

Permalink
Allow non-root trivial reductions (#2037)
Browse files Browse the repository at this point in the history
* Allow non-root trivial reductions

Fixes #2008

Co-authored-by: Christian Sarofeen <[email protected]>
  • Loading branch information
naoyam and csarofeen authored Oct 6, 2022
1 parent a2dfe40 commit 0b8e83f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 27 deletions.
12 changes: 2 additions & 10 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1416,16 +1416,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
}

//! Check if IterDomain is a reduction axis with size of 1, i.e.
//! a "squeeze" operator.
//!
//! NOTE: Detection of trivial reduction here is not
//! comprehensive. See detectTrivialReductionDerivedDomains for more
//! comprehensive analysis. We typically use this for root domain trivial
//! reduction checks. So we ship to the correct scheduler. It may
//! not be incredibly robust, but it makes sense to keep it for now.
bool isTrivialReduction() const {
return isReduction() && extent()->isOneInt();
}
//! a "squeeze" operator, or solely derived from such axes.
bool isTrivialReduction() const;

//! Split for stride by a given factor. It effectively does an inner
//! split by the factor and sets the inner domain as a Stride
Expand Down
37 changes: 36 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,37 @@ IterDomain* IterDomain::cloneWithoutRFactor() const {
return cloned;
}

bool IterDomain::isTrivialReduction() const {
  // Only reduction domains can qualify as trivial reductions.
  if (!isReduction()) {
    return false;
  }

  // A reduction axis with an extent of 1 is a "squeeze" and thus
  // trivially reducible.
  if (extent()->isOneInt()) {
    return true;
  }

  // A root domain with extent != 1 cannot be a trivial reduction.
  if (definition() == nullptr) {
    return false;
  }

  // Non-root domain: it is a trivial reduction iff every root domain
  // it is derived from is itself a size-1 reduction. This is almost
  // the same as the analysis done in TrivialReductionInfo, but is
  // limited within a single tensor, whereas TrivialReductionInfo does
  // more expensive analysis potentially traversing through rfactor
  // domains.
  // Note: There's no const version of IterVisitor.
  auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this));
  for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) {
    if (!root_id->isReduction() || !root_id->extent()->isOneInt()) {
      return false;
    }
  }
  return true;
}

std::vector<IterDomain*> IterDomain::clone(
const std::vector<IterDomain*>& domains) {
std::vector<IterDomain*> cloned_domains;
Expand All @@ -1744,7 +1775,11 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
outer->isReduction() == inner->isReduction() ||
(!outer->isReduction() && inner->isTrivialReduction()) ||
(outer->isTrivialReduction() && !inner->isReduction()),
"Merging IterDomains requires that their iteration types match.");
"Merging IterDomains requires that their iteration types match. ",
"Outer: ",
outer->toString(),
", Inner: ",
inner->toString());
TORCH_CHECK(
(outer->isGather() && inner->isGather()) ||
(!outer->isGather() && !inner->isGather()),
Expand Down
63 changes: 55 additions & 8 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22304,26 +22304,40 @@ TEST_F(NVFuserTest, FusionTrivialReductionForwarding3_CUDA) {
auto tv2 = add(tv1, IrBuilder::create<Double>(1));
fusion.addOutput(tv2);

// Similar pattern as FusionTrivialReductionForwarding2 but no
// trivial reduction at the root domain
// Similar pattern as FusionTrivialReductionForwarding2 but trivial
// reduction at non-root domain

// Create a trivial reduction by splitting with a factor of 1
tv1->split(1, 1, false);
// Merging with a trivial reduction
tv1->merge(0, 1);
auto tv1_merge_out_id = tv1->axis(0);
tv1->split(0, 5);

tv2->split(0, 5);

// While the merge of tv1 is done with a trivial reduction, it's not
// a root domain, so forwarding is not enabled. BestEffortReplay
// should only map the first axis of each tensor.
// The merge of tv1 is done with a non-root trivial
// reduction. BestEffortReplay should forward the merge.

PairwiseRootDomainMap root_map(tv1, tv2);
auto p2c = BestEffortReplay::replayCasP(tv2, tv1, 2, root_map).getReplay();
TORCH_CHECK(p2c.size() == 1, "Expected only one mapping found");
TORCH_CHECK(p2c.begin()->first == tv1->getRootDomain().at(0));
TORCH_CHECK(p2c.begin()->second == tv2->getRootDomain().at(0));

// The two tensors should look like:
// tv1: [I1*1//5, 5, I2//1]
// tv2: [I1//5, 5]
//
// BestEffortReplay should forward the merge of (I1 * 1) and create
// mappings of:
// I1*1//5 -> I1//5
// 5 -> 5
// I1*1 -> I1

TORCH_CHECK(p2c.size() == 3, "Unexpected number of mappings");
TORCH_CHECK(p2c.count(tv1->axis(0)) && p2c[tv1->axis(0)] == tv2->axis(0));
TORCH_CHECK(p2c.count(tv1->axis(1)) && p2c[tv1->axis(1)] == tv2->axis(1));
TORCH_CHECK(
p2c.count(tv1_merge_out_id) &&
p2c[tv1_merge_out_id] == tv2->getRootDomain()[0]);
}

TEST_F(NVFuserTest, FusionTrivialReductionForwarding4_CUDA) {
Expand Down Expand Up @@ -26125,6 +26139,39 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) {
testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__);
}

// Simplified repro of issue #2008: exercises replay/propagation of a
// merge involving trivial reductions that are not root domains, all
// the way through compilation and execution.
TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // Concrete shape with two size-1 inner dimensions, so the sum over
  // axes {1, 2} below reduces only size-1 axes.
  std::vector<int64_t> shape({10, 1, 1});

  auto tv0 = makeConcreteTensor(shape);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  // Reduce the two size-1 axes, then broadcast them back.
  auto tv2 = sum(tv1, {1, 2});
  auto tv3 = broadcast(tv2, {false, true, true});
  fusion.addOutput(tv3);

  // Collapse all three axes into one and split; when these transforms
  // are propagated to tv2, the merged-in size-1 reduction axes become
  // non-root trivial reductions — presumably the case issue #2008
  // failed on (TODO confirm against the issue).
  tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4);

  // Propagate tv0's transformations to the other tensors. Ordering
  // matters: tv0 must be transformed before the traversal.
  MaxRootDomainInfoSpanningTree tree(tv0);
  TransformPropagator tp(tv0);
  tree.traverse(&tp);

  // NOTE(review): `kFloat` is unqualified here while `at::` is used
  // for the other names on this line — confirm a using-declaration in
  // this file makes it resolve, or qualify as `at::kFloat`.
  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn(shape, options);
  std::vector<IValue> aten_inputs({t0});

  FusionExecutor fe;
  fe.compileFusion(fusion_ptr.get(), aten_inputs);
  auto outputs = fe.runFusion(aten_inputs);

  // Reducing and re-broadcasting size-1 axes preserves values, so the
  // expected result is simply t0 + 1.
  testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__);
}

namespace {

size_t getVecSizeForPointwise(FusionExecutorCache& fec) {
Expand Down
8 changes: 0 additions & 8 deletions torch/csrc/jit/codegen/cuda/transform_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,14 +762,6 @@ struct ProducerForwardingInfo {
(outer->isTrivialReduction() && !inner->isReduction())) {
auto compliment_id = inner->isTrivialReduction() ? inner : outer;
auto forwarded_id = inner->isTrivialReduction() ? outer : inner;
// Only allow forwarding when the trivial reduction domain is
// an root domain
if (std::find(
producer->getMaybeRFactorDomain().begin(),
producer->getMaybeRFactorDomain().end(),
compliment_id) == producer->getMaybeRFactorDomain().end()) {
continue;
}
forwarding_map.emplace(std::make_pair(forwarded_id, merge->out()));
compliment_map.emplace(std::make_pair(
forwarded_id, std::vector<IterDomain*>{compliment_id}));
Expand Down

0 comments on commit 0b8e83f

Please sign in to comment.