diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index fe3a5b6f457..eae9d4dc7bd 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -127,6 +127,30 @@ void TransposeScheduler::computeHeuristics( namespace { +// If a fusion is segmented, the segmenter will create fusions whose inputs +// contain reduction IterDomains. These reduction IterDomains on input +// TensorViews do not have any meaning, and should just be left untouched. See +// https://github.com/NVIDIA/Fuser/issues/1659#issuecomment-1907053830 +// +// This function checks the inner `n` IterDomains and reorders reduction +// IterDomains to the beginning. +void moveReductionsOut(TensorView* tv, int n) { + if (!tv->isFusionInput()) { + return; + } + + std::unordered_map<int, int> old2new; + + int target = 0; + for (int i = 0; i < n; i++) { + if (tv->axis(-1 - i)->isReduction()) { + old2new[-1 - i] = target++; + } + } + + tv->reorder(old2new); +} + // TransposeViewPropagator doesn't propagate anything. It simply walks across // the path of potential propagation checking if there's any incompatible // propagation that would not be resolved. 
@@ -1236,6 +1260,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { int pos = (int)reference2->nDims() - 2; // [..., tile1, tile2] + moveReductionsOut(reference2, 2); reference2->merge(pos); reference2->split(pos, params.vectorize_factor2); reference2->split(pos, params.getThreadsPerBlock()); @@ -1321,6 +1346,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { reference1->reorder({{-2, -1}}); // [..., tile2, tile1] pos = (int)reference1->nDims() - 2; + moveReductionsOut(reference1, 2); reference1->merge(pos); reference1->split(pos, params.vectorize_factor1); reference1->split(pos, params.getThreadsPerBlock()); diff --git a/test/test_gpu_transpose.cpp b/test/test_gpu_transpose.cpp index c54e2cbba74..9fe61d6c3d7 100644 --- a/test/test_gpu_transpose.cpp +++ b/test/test_gpu_transpose.cpp @@ -1335,4 +1335,57 @@ TEST_F(TransposeTest, TransposeSplitAggregatedVectorizationWidth) { NVF_CHECK(ref.equal(cg_outputs.at(0))); } +// Testing transpose scheduler to handle fusion inputs with reduction IterDomain +// produced by segmented fusion, see issue +// https://github.com/NVIDIA/Fuser/issues/1659 for details +TEST_F(TransposeTest, ReductionIterDomainOnInputsIssue1659) { + auto fusion = std::make_unique<Fusion>(); + auto fusion_ptr = fusion.get(); + FusionGuard fg(fusion_ptr); + + auto tv0 = TensorViewBuilder() + .ndims(3) + .contiguity({true, true, std::nullopt}) + .shape({-1, -1, 1}) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv1 = TensorViewBuilder() + .ndims(3) + .contiguity({true, std::nullopt, true}) + .shape({-1, 1, -1}) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv1); + auto tv2 = sum(tv0, {1}); + auto tv3 = squeeze(tv1, std::vector<int64_t>{1}); + auto tv4 = add(tv2, tv3); + fusion->addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn({1024, 512, 1}, options); + auto t1 = at::randn({1024, 1, 512}, options); + std::vector<c10::IValue> aten_inputs({t0, t1}); 
+ + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(runtime->isSegmented(), "Segmentation expected"); + auto heuristic0 = + runtime->schedulerHeuristics()->heuristicsList().at(0).get()->heuristic(); + NVF_CHECK( + heuristic0 == ScheduleHeuristic::Reduction, + "Unexpected heuristic: ", + heuristic0); + auto heuristic1 = + runtime->schedulerHeuristics()->heuristicsList().at(1).get()->heuristic(); + NVF_CHECK( + heuristic1 == ScheduleHeuristic::Transpose, + "Unexpected heuristic: ", + heuristic1); + testValidate(fusion_ptr, cg_outputs, {t0, t1}, __LINE__, __FILE__); +} + } // namespace nvfuser