Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transpose scheduler fix: reduction IterDomain on input tensors #1661

Merged
merged 14 commits into from
Feb 1, 2024
26 changes: 26 additions & 0 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,30 @@ void TransposeScheduler::computeHeuristics(

namespace {

// If a fusion is segmented, the segmenter will create fusions whose inputs
// contain reduction IterDomains. These reduction IterDomains on input
// TensorViews does not have any meaning, and should just be left untouched. See
// https://github.com/NVIDIA/Fuser/issues/1659#issuecomment-1907053830
//
// This function checks the inner `n` iterdomain and reorder reduction
// iterdomain to the beginning.
void moveReductionsOut(TensorView* tv, int n) {
if (!tv->isFusionInput()) {
return;
}

std::unordered_map<int, int> old2new;

int target = 0;
for (int i = 0; i < n; i++) {
if (tv->axis(-1 - i)->isReduction()) {
old2new[-1 - i] = target++;
}
}

tv->reorder(old2new);
}

// TransposeViewPropagator doesn't propagate anything. It simply walks across
// the path of potential propagation checking if there's any incompatible
// propagation that would not be resolved.
Expand Down Expand Up @@ -1236,6 +1260,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {

int pos = (int)reference2->nDims() - 2;
// [..., tile1, tile2]
moveReductionsOut(reference2, 2);
reference2->merge(pos);
reference2->split(pos, params.vectorize_factor2);
reference2->split(pos, params.getThreadsPerBlock());
Expand Down Expand Up @@ -1321,6 +1346,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
reference1->reorder({{-2, -1}});
// [..., tile2, tile1]
pos = (int)reference1->nDims() - 2;
moveReductionsOut(reference1, 2);
reference1->merge(pos);
reference1->split(pos, params.vectorize_factor1);
reference1->split(pos, params.getThreadsPerBlock());
Expand Down
53 changes: 53 additions & 0 deletions test/test_gpu_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,4 +1335,57 @@ TEST_F(TransposeTest, TransposeSplitAggregatedVectorizationWidth) {
NVF_CHECK(ref.equal(cg_outputs.at(0)));
}

// Testing transpose scheduler to handle fusion inputs with reduction IterDomain
// produced by segmented fusion, see issue
// https://github.com/NVIDIA/Fuser/issues/1659 for details
TEST_F(TransposeTest, ReductionIterDomainOnInputsIssue1659) {
auto fusion = std::make_unique<Fusion>();
auto fusion_ptr = fusion.get();
FusionGuard fg(fusion_ptr);

auto tv0 = TensorViewBuilder()
.ndims(3)
.contiguity({true, true, std::nullopt})
.shape({-1, -1, 1})
.dtype(DataType::Float)
.build();
fusion->addInput(tv0);
auto tv1 = TensorViewBuilder()
.ndims(3)
.contiguity({true, std::nullopt, true})
.shape({-1, 1, -1})
.dtype(DataType::Float)
.build();
fusion->addInput(tv1);
auto tv2 = sum(tv0, {1});
auto tv3 = squeeze(tv1, std::vector<int64_t>{1});
auto tv4 = add(tv2, tv3);
fusion->addOutput(tv4);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

auto t0 = at::randn({1024, 512, 1}, options);
auto t1 = at::randn({1024, 1, 512}, options);
std::vector<c10::IValue> aten_inputs({t0, t1});

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

auto runtime = executor_cache.getMostRecentKernelRuntime();
NVF_CHECK(runtime->isSegmented(), "Segmentation expected");
auto heuristic0 =
runtime->schedulerHeuristics()->heuristicsList().at(0).get()->heuristic();
NVF_CHECK(
heuristic0 == ScheduleHeuristic::Reduction,
"Unexpected heuristic: ",
heuristic0);
auto heuristic1 =
runtime->schedulerHeuristics()->heuristicsList().at(1).get()->heuristic();
NVF_CHECK(
heuristic1 == ScheduleHeuristic::Transpose,
"Unexpected heuristic: ",
heuristic1);
testValidate(fusion_ptr, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

} // namespace nvfuser
Loading