Vectorization Factor patch for computeInfoC2P with Broadcast in mapped IterDomain (#1625)

Fixes #1567

This PR patches the vectorization factor computation in
`ContiguousInnerDimensionsMapper::computeInfoC2P`.

Resolved broadcast dimensions should be handled in the order of the mapped
consumer tensors' from_ids, instead of the root_domain order (see the sketch
after the case list below). Added a few tests per @zasdfgbnm's suggestion:

```
Case 0:
T2[1024, 2, 512] = T0[1024, 2, 1] + T1[1024, 2, 512]
allocation = rfactor
--> T0 has no vectorization

Case 1:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2]
allocation = rfactor
--> T0 has vectorization 2

Case 2:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2];
T3[512, 1024, 2] = transpose(T2[1024, 512, 2])
allocation = rfactor
*except T1 has stride_order {1, 2, 0}
--> T0 has vectorization 4

Case 3:
T2[512, 1024, 2] = T0[1, 1024, 2] + T1[512, 1024, 2]
T3[1024, 512, 2] = transpose(T2[512, 1024, 2])
allocation = rfactor
--> T0 has vectorization 2
```
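
For reference, here is a minimal standalone sketch of the new scan over from_ids. It uses simplified stand-in types instead of nvFuser's IterDomain and PairwiseRootDomainMap, so `MappedId` and `computeClearPos` are hypothetical names and not part of the codebase:

```
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for a consumer root IterDomain paired with the producer id it maps
// to; we only track whether either side is a broadcast.
struct MappedId {
  bool consumer_is_broadcast;
  bool producer_is_broadcast;
};

// Scan the mapped consumer ids (from_ids) from innermost to outermost and
// return the index of the first id that survives: the innermost resolved
// broadcast and everything to its left are dropped from the mapped set.
size_t computeClearPos(const std::vector<MappedId>& from_ids) {
  size_t clear_pos = 0;
  for (int i = static_cast<int>(from_ids.size()) - 1; i >= 0; --i) {
    if (!from_ids[i].consumer_is_broadcast &&
        from_ids[i].producer_is_broadcast) {
      clear_pos = static_cast<size_t>(i) + 1;
      break;
    }
  }
  return clear_pos;
}

int main() {
  // Mapped from_ids ordered {i0, i2, i1}, as in the comment added to
  // computeInfoC2P: i2 resolves the producer broadcast b2, so only the ids to
  // its right (here just i1) are kept for the vectorization factor.
  std::vector<MappedId> from_ids = {
      {false, false}, // i0 <-> i0
      {false, true},  // i2 <-> b2 (resolved broadcast)
      {false, false}, // i1 <-> i1
  };
  std::cout << "clear_pos = " << computeClearPos(from_ids) << "\n"; // prints 2
  return 0;
}
```

In Case 0 above the innermost consumer id itself resolves the broadcast, so the same scan would set clear_pos to the full length of from_ids, leaving nothing to vectorize for T0.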

---------

Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: Gao, Xiang <[email protected]>
3 people authored Feb 1, 2024
1 parent 3e1e11e commit 0cf0a25
Showing 2 changed files with 255 additions and 25 deletions.
51 changes: 26 additions & 25 deletions csrc/scheduler/vectorize_helper.cpp
@@ -436,6 +436,9 @@ ContiguousInnerDimensionsMapper::computeInfoC2P(
std::shared_ptr<MaxInfoSpanningTree::Information> from_info) {
auto from_ids = std::dynamic_pointer_cast<const MappedDomain>(from_info)
->mapped_root_ids_;
// When we propagate, we should check the resolved broadcast in the order of
// mapped from_ids.
//
// If we have a case where we have a concretized broadcast that's being
// tracked in a consumer but not concretized in the producer we should break
// off the dimensions connected to the left of that dimension. So if we have:
@@ -445,49 +448,47 @@
// T3[i0, i1, i2] = T1 + T2
// and we're propagating from T3 with {i0, i1, i2}
// When we go from T3 to T0, we don't have any mechanism to understand that i0
// and i2 are not contiguous in the original domain of T3. It's not ideal with
// transpose, but when this happens we'll clear all dimensions mapped left of
// the concretized broadcast.
// So if we have:
// T0[i1, i2]
// T1[b0, i1, i2] = broadcast(T0)
// T2[i1, b0, i2] = transpose(T1)
// T3[i1, i0, i2]
// T4[i1, i0, i2] = T2 + T3
// T5[i0, i1, i2] = transpose(T4)
// Then i1 and i2 are contiguous in both T0 and T5, but due to the realization
// of the broadcast on T4 we will have removed i1 from the mapped set.
// and i2 are not contiguous in the original domain of T3.
//
// Another example: if the last broadcast dimension resolved in the
// consumer's root domain is mapped for vectorization, the merge order of
// the vectorization axes matters.
//
// T0[i0, i1]
// T1[i0, i1, b2] = broadcast(T0)
// T2[i0, i1, i3]
// T3[i0, i1, i2] = T1 + T2
//
// If the mapped ids are {i0, i2, i1}, when propagating from T3 to T1, the
// resolved broadcast iterdomain is `i2`/`b2`, which gives clear_pos=2 (one
// past the resolved broadcast). So we'll skip all from_ids with index <
// clear_pos. See issue:
// https://github.com/NVIDIA/Fuser/issues/1567#issuecomment-1894605385
PairwiseRootDomainMap root_map(to, from);
auto c2p_map = root_map.mapConsumerToProducer();

// Id's in consumer to clear from the mapped set due to broadcast
// concretization.
std::unordered_set<IterDomain*> consumer_ids_to_clear;
size_t clear_pos = 0;
if (to->hasBroadcast()) {
// Find the last broadcast dimension resolved in consumers root domain
int clear_pos = -1;
for (auto i : c10::irange(from->getRootDomain().size())) {
auto c_id = from->getRootDomain()[i];
// Find the last broadcast dimension resolved in consumers through from_ids
for (int i = (int)from_ids.size() - 1; i >= 0; i--) {
auto c_id = from_ids[i];
auto c_it = c2p_map.find(c_id);
if (c_it == c2p_map.end()) {
continue;
}
auto p_id = c_it->second;
if ((!c_id->isBroadcast()) && p_id->isBroadcast()) {
clear_pos = (int)i;
clear_pos = (size_t)i + 1;
break;
}
}
// Clear everything to the left of the inner most resolved broadcast
// dimension, including the broadcasted domain.
if (clear_pos >= 0) {
consumer_ids_to_clear.insert(
from->getRootDomain().begin(),
from->getRootDomain().begin() + clear_pos + 1);
}
}

std::vector<IterDomain*> producer_rfactor_ids;
for (auto from_id : from_ids) {
for (auto i : c10::irange(clear_pos, from_ids.size())) {
auto from_id = from_ids[i];
auto c2p_it = c2p_map.find(from_id);
if (c2p_it != c2p_map.end() &&
consumer_ids_to_clear.find(c2p_it->first) ==
229 changes: 229 additions & 0 deletions test/test_pointwise.cpp
@@ -33,6 +33,22 @@ size_t getVecSizeForPointwise(FusionExecutorCache& fec) {
return 1;
}

bool hasVectorizationCache(TensorView* tv) {
NVF_CHECK(tv->isFusionInput());
NVF_CHECK(tv->uses().size() == 1);
auto set_expr = dynamic_cast<LoadStoreOp*>(tv->uses().at(0));
NVF_CHECK(set_expr != nullptr && set_expr->opType() == LoadStoreOpType::Set);
auto cached_input = set_expr->out()->as<TensorView>();
NVF_CHECK(cached_input, "expects input to be cached");

for (const auto* id : cached_input->getLeafDomain()) {
if (id->getParallelType() == ParallelType::Vectorize) {
return true;
}
}
return false;
}

} // namespace

TEST_F(PointwiseTest, VectorizeStrideContiguity2D) {
Expand Down Expand Up @@ -201,4 +217,217 @@ TEST_F(PointwiseTest, VectorizeAllocationDomain) {
testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__);
}

// All inputs & outputs share the same allocation domain permutation from the
// root domain, but the intermediate tv2 has no stride order specified. There's
// also a broadcast IterDomain on tv1, which makes it tricky for the
// vectorization analysis to figure out which axes should be excluded from the
// computation of the vectorization factor.
TEST_F(PointwiseTest, Issue1567VectorizeAllocationDomain) {
auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

TensorView* tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, true})
.strideOrder({2, 0, 1})
.build();
TensorView* tv1 = TensorViewBuilder()
.ndims(3)
.shape({1, -1, 1})
.contiguity({std::nullopt, std::nullopt, true})
.strideOrder({2, 0, 1})
.build();
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = add(tv0, tv1);
auto tv3 = add(tv2, IrBuilder::create<Val>(1.0, DataType::Float));
tv3->setAllocationDomain({tv3->axis(0), tv3->axis(2), tv3->axis(1)}, true);
fusion->addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::empty_strided({1024, 128, 25}, {128*25, 1, 128}, options);
at::Tensor t1 = at::empty_strided({1, 128, 1}, {128, 1, 128}, options);
std::vector<c10::IValue> aten_inputs = {t0, t1};

// NOTE: force the pointwise scheduler here just for testing purposes
auto params = getPointwiseHeuristics(fusion, aten_inputs);
auto lparams = schedulePointwise(fusion, aten_inputs);
FusionExecutor fe;
fe.compileFusion(fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

EXPECT_EQ(params->vectorize, true);
EXPECT_EQ(params->unroll_factor, 4);
EXPECT_TRUE(hasVectorizationCache(tv0));
EXPECT_TRUE(hasVectorizationCache(tv1));

testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase0) {
auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

TensorView* tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, std::nullopt})
.shape({-1, -1, 1})
.build();
TensorView* tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, true})
.build();
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = add(tv0, tv1);
fusion->addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024, 2, 1}, options);
at::Tensor t1 = at::randn({1024, 2, 512}, options);
std::vector<c10::IValue> aten_inputs = {t0, t1};

// NOTE: force the pointwise scheduler here just for testing purposes
auto params = getPointwiseHeuristics(fusion, aten_inputs);
auto lparams = schedulePointwise(fusion, aten_inputs);
FusionExecutor fe;
fe.compileFusion(fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

EXPECT_EQ(params->vectorize, true);
EXPECT_EQ(params->unroll_factor, 4);
EXPECT_FALSE(hasVectorizationCache(tv0));
EXPECT_TRUE(hasVectorizationCache(tv1));

testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase1) {
auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

TensorView* tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, std::nullopt, true})
.shape({-1, 1, -1})
.build();
TensorView* tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, true})
.build();
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = add(tv0, tv1);
fusion->addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024, 1, 2}, options);
at::Tensor t1 = at::randn({1024, 512, 2}, options);
std::vector<c10::IValue> aten_inputs = {t0, t1};

// NOTE: force the pointwise scheduler here just for testing purposes
auto params = getPointwiseHeuristics(fusion, aten_inputs);
auto lparams = schedulePointwise(fusion, aten_inputs);
FusionExecutor fe;
fe.compileFusion(fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

EXPECT_EQ(params->vectorize, true);
EXPECT_EQ(params->unroll_factor, 2);
EXPECT_TRUE(hasVectorizationCache(tv0));
EXPECT_TRUE(hasVectorizationCache(tv1));

testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) {
auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

TensorView* tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, std::nullopt, true})
.shape({-1, 1, -1})
.build();
TensorView* tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, true})
.strideOrder({1, 2, 0})
.build();
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = add(tv0, tv1);
auto tv3 = transpose(tv2, 0, 1);
fusion->addOutput(tv3);

FusionExecutorCache fec(std::move(fusion_ptr));
fec.profile(true);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024, 1, 2}, options);
at::Tensor t1 = at::empty_strided({1024, 512, 2}, {2, 2048, 1}, options);
std::vector<c10::IValue> aten_inputs = {t0, t1};

// NOTE: force the pointwise scheduler here just for testing purposes
auto params = getPointwiseHeuristics(fusion, aten_inputs);
auto lparams = schedulePointwise(fusion, aten_inputs);
FusionExecutor fe;
fe.compileFusion(fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

EXPECT_EQ(params->vectorize, true);
EXPECT_EQ(params->unroll_factor, 4);
EXPECT_TRUE(hasVectorizationCache(tv0));
EXPECT_TRUE(hasVectorizationCache(tv1));

testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase3) {
auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

TensorView* tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({std::nullopt, true, true})
.shape({1, -1, -1})
.build();
TensorView* tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, true})
.build();
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = add(tv0, tv1);
auto tv3 = transpose(tv2, 0, 1);
fusion->addOutput(tv3);

FusionExecutorCache fec(std::move(fusion_ptr));
fec.profile(true);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1, 1024, 2}, options);
at::Tensor t1 = at::randn({512, 1024, 2}, options);
std::vector<c10::IValue> aten_inputs = {t0, t1};

// NOTE: force the pointwise scheduler here just for testing purposes
auto params = getPointwiseHeuristics(fusion, aten_inputs);
auto lparams = schedulePointwise(fusion, aten_inputs);
FusionExecutor fe;
fe.compileFusion(fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

EXPECT_EQ(params->vectorize, true);
EXPECT_EQ(params->unroll_factor, 2);
EXPECT_TRUE(hasVectorizationCache(tv0));
EXPECT_TRUE(hasVectorizationCache(tv1));

testValidate(fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
}

} // namespace nvfuser
