
transpose scheduler fix: reduction IterDomain on input tensors #1661

Merged 14 commits on Feb 1, 2024
23 changes: 23 additions & 0 deletions csrc/scheduler/transpose.cpp
@@ -127,6 +127,27 @@ void TransposeScheduler::computeHeuristics(

namespace {

// Propagation could miss trivial reduction IterDomains on input tensors,
// since they have no dependencies and don't map to anything. We can naively
// reorder them to the outermost positions so they won't interfere with
// tiling.
// See https://github.com/NVIDIA/Fuser/issues/1659#issuecomment-1907053830

Review comment (Collaborator): As Xiang also commented, "trivial reduction" could be confusing as it would imply squeeze. Can you please rephrase the comment?
void cleanInnerNDLeafDomain(TensorView* tv, int n) {
if (!tv->isFusionInput()) {
return;
}

std::unordered_map<int, int> old2new;

int target = 0;
for (int i = 0; i < n; i++) {
if (tv->axis(-1 - i)->isReduction()) {
old2new[-1 - i] = target++;
}
}

tv->reorder(old2new);
}
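
For illustration, a minimal sketch of what the reorder does (the TensorView and its leaf domain here are hypothetical, not from the PR):

  // Suppose tv is a fusion input with leaf domain [i0, r1, i2] and n == 2:
  //   i = 0: axis(-1) == i2 is not a reduction -> skipped
  //   i = 1: axis(-2) == r1 is a reduction     -> old2new[-2] = 0
  // tv->reorder({{-2, 0}}) then yields [r1, i0, i2], so the two innermost
  // positions consumed by tiling hold only iteration domains.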

// TransposeViewPropagator doesn't propagate anything. It simply walks across
// the path of potential propagation, checking whether any incompatible
// propagation would be left unresolved.
@@ -1236,6 +1257,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {

int pos = (int)reference2->nDims() - 2;
// [..., tile1, tile2]
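  // Move any reduction IterDomains out of the two innermost (tile) positions
  // so the merge/split tiling below only touches iteration domains.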
cleanInnerNDLeafDomain(reference2, 2);
reference2->merge(pos);
reference2->split(pos, params.vectorize_factor2);
reference2->split(pos, params.getThreadsPerBlock());
@@ -1321,6 +1343,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
reference1->reorder({{-2, -1}});
// [..., tile2, tile1]
pos = (int)reference1->nDims() - 2;
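  // As above, clear reduction IterDomains from the two innermost positions
  // before tiling reference1.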
cleanInnerNDLeafDomain(reference1, 2);
reference1->merge(pos);
reference1->split(pos, params.vectorize_factor1);
reference1->split(pos, params.getThreadsPerBlock());
52 changes: 52 additions & 0 deletions test/test_gpu_transpose.cpp
@@ -1335,4 +1335,56 @@ TEST_F(TransposeTest, TransposeSplitAggregatedVectorizationWidth) {
NVF_CHECK(ref.equal(cg_outputs.at(0)));
}

// Test that the transpose scheduler handles a trivial reduction IterDomain
// produced by fusion segmentation; see issue 1659 for details.
TEST_F(TransposeTest, TrivialReductionIterDomainOnInputsIssueRepro1659) {
Review comment (Collaborator): Please also change the comment here to avoid the term "trivial reduction"

auto fusion = std::make_unique<Fusion>();
auto fusion_ptr = fusion.get();
FusionGuard fg(fusion_ptr);

auto tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, std::nullopt})
.shape({-1, -1, 1})
.dtype(DataType::Float)
.build();
fusion->addInput(tv0);
auto tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, std::nullopt, true})
.shape({-1, 1, -1})
.dtype(DataType::Float)
.build();
fusion->addInput(tv1);
auto tv2 = sum(tv0, {1});
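  // tv2 carries a reduction IterDomain; after segmentation (verified below),
  // it reaches the transpose segment as an input, reproducing issue 1659.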
auto tv3 = squeeze(tv1, std::vector<int64_t>{1});
auto tv4 = add(tv2, tv3);
fusion->addOutput(tv4);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

auto t0 = at::randn({1024, 512, 1}, options);
auto t1 = at::randn({1024, 1, 512}, options);
std::vector<c10::IValue> aten_inputs({t0, t1});

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

auto runtime = executor_cache.getMostRecentKernelRuntime();
NVF_CHECK(runtime->isSegmented(), "Segmentation expected");
auto heuristic0 =
runtime->schedulerHeuristics()->heuristicsList().at(0).get()->heuristic();
NVF_CHECK(
heuristic0 == ScheduleHeuristic::Reduction,
"Unexpected heuristic: ",
heuristic0);
auto heuristic1 =
runtime->schedulerHeuristics()->heuristicsList().at(1).get()->heuristic();
NVF_CHECK(
heuristic1 == ScheduleHeuristic::Transpose,
"Unexpected heuristic: ",
heuristic1);
testValidate(fusion_ptr, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

} // namespace nvfuser