diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp
index f0bea85d3b8..134a372f74a 100644
--- a/csrc/scheduler/transpose.cpp
+++ b/csrc/scheduler/transpose.cpp
@@ -1219,23 +1219,30 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
   reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}});
   // [..., I1/tile1, .., I2/tile2, ..., tile1, tile2]
 
-  // Merge remaining dimensions
-  int64_t lhs_i = -1;
-  for (int64_t i = reference1->nDims() - 2; i > 0; i--) {
-    auto axis_i = i - 1;
-    if (lhs_i == -1) {
-      lhs_i = axis_i;
-    } else {
-      reference1->merge(axis_i, lhs_i);
-      lhs_i = axis_i;
+  // Merge remaining dimensions, ignoring reduction axes (see Issue #2317).
+  // The reduction axes can be at any position.
+  // For example, [i0, r1, i1, r2, i2] after tiling becomes
+  // [i0, r1, i1/tile1, r2, i2/tile2, tile1, tile2]. The following code merges
+  // the outer IterDomains into [i0 * i1/tile1 * i2/tile2, r1, r2, tile1, tile2].
+  int64_t rhs_i = reference1->nDims() - 3;
+  for (int64_t lhs_i = reference1->nDims() - 4; lhs_i >= 0; lhs_i--) {
+    if (reference1->axis(lhs_i)->isReduction()) {
+      continue;
+    }
+    if (reference1->axis(rhs_i)->isReduction()) {
+      rhs_i = lhs_i;
+      continue;
     }
+    reference1->merge(lhs_i, rhs_i);
+    rhs_i = lhs_i;
   }
-  reference1->split(0, 1);
-  // [merged_dim, 1, tile1, tile2]
+
+  reference1->split(rhs_i, 1);
+  // [r.., merged_dim, 1, tile1, tile2]
 
   // parallelize non-tile dimensions
-  reference1->axis(1)->parallelize(ParallelType::Unswitch);
-  reference1->axis(0)->parallelize(ParallelType::BIDx);
+  reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch);
+  reference1->axis(rhs_i)->parallelize(ParallelType::BIDx);
   // [BIDx, Unswitch, tile1, tile2]
 
   // Propagate transformations so far to the entire DAG
diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py
index f6fe5204f79..45c0d112bed 100644
--- a/tests/python/test_python_frontend.py
+++ b/tests/python/test_python_frontend.py
@@ -4018,6 +4018,33 @@ def fusion_func(fd: FusionDefinition) -> None:
 
         nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
 
+    # See https://github.com/NVIDIA/Fuser/issues/2317
+    @unittest.skipIf(is_pre_ampere(), "Only supported on Ampere and newer devices.")
+    def test_reduction_transpose_sched_issue2317(self):
+        inputs = [
+            torch.randn((16, 25, 128, 64), dtype=torch.bfloat16, device="cuda:0"),
+            torch.randn((16, 128, 1600), dtype=torch.bfloat16, device="cuda:0"),
+            torch.randn((1600, 1600), dtype=torch.bfloat16, device="cuda:0"),
+        ]
+
+        def fusion_func(fd: FusionDefinition, inputs) -> None:
+            T0 = fd.from_pytorch(inputs[0])
+            T1 = fd.from_pytorch(inputs[1])
+            T2 = fd.from_pytorch(inputs[2])
+
+            T10 = fd.ops.permute(T0, dims=[0, 2, 1, 3])
+            T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
+            T16 = fd.ops.reshape(T11, new_shape=T1.shape())
+            T17 = fd.ops.linear(T16, T2)
+            T33 = fd.ops.add(T17, T1)
+
+            T33 = fd.ops.cast(T33, dtype=DataType.BFloat16)
+            T34 = fd.ops.linear(T33, T2)
+            T35 = fd.ops.add(T34, T33)
+            fd.add_output(T35)
+
+        nvf_out, _ = self.exec_nvfuser(partial(fusion_func, inputs=inputs), inputs)
+
     def test_fusion_profiler(self):
         inputs = [
             torch.randn((2, 5), dtype=torch.float, device="cuda:0"),
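
Note (editor's sketch, not part of the patch): the new merge loop walks the outer axes
right-to-left, skipping reduction IterDomains on the left-hand side and retargeting the
right-hand side whenever it happens to be a reduction, so only iteration axes are merged.
The minimal Python sketch below mimics that index bookkeeping on a plain list of axis
labels; the label convention (names starting with "r" mark reduction axes) and the helper
name are illustrative assumptions, not nvFuser API.

def merge_outer_non_reduction_axes(axes):
    # Mimic the scheduler loop: merge every outer non-reduction axis into one,
    # leaving the two innermost tile axes untouched. Returns the rewritten
    # axis list and the position of the merged axis (the final rhs_i).
    rhs_i = len(axes) - 3
    lhs_i = len(axes) - 4
    while lhs_i >= 0:
        if axes[lhs_i].startswith("r"):
            lhs_i -= 1          # skip reduction axes on the lhs
            continue
        if axes[rhs_i].startswith("r"):
            rhs_i = lhs_i       # rhs was a reduction; retarget it to the lhs
            lhs_i -= 1
            continue
        # "merge(lhs_i, rhs_i)": fold the rhs axis into the lhs position
        axes[lhs_i] = axes[lhs_i] + "*" + axes.pop(rhs_i)
        rhs_i = lhs_i
        lhs_i -= 1
    return axes, rhs_i

# Example from the comment in the patch:
axes = ["i0", "r1", "i1/tile1", "r2", "i2/tile2", "tile1", "tile2"]
print(merge_outer_non_reduction_axes(axes))
# (['i0*i1/tile1*i2/tile2', 'r1', 'r2', 'tile1', 'tile2'], 0)

Because the merged axis can end up at a position other than 0 when reduction axes remain
outermost, the split and the BIDx/Unswitch bindings in the patch index axis(rhs_i) and
axis(rhs_i + 1) instead of axis(0) and axis(1).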