Skip to content

Commit

Permalink
Transpose Scheduler: Ignore reduction axis in merge (#2326)
Browse files Browse the repository at this point in the history
Issue #2317.

The issue arises in the following lines for reference 1: `[I0, I1, I2,
r3]`:

After tiling:
```
  reference1->split(inner_most_pos1_in_ref1, params.tile_size1);
  reference1->reorder({{inner_most_pos1_in_ref1 + 1, -1}});
  reference1->split(inner_most_pos2_in_ref1, params.tile_size2);
  reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}});
```
Reference 1 is: [I0, I1/tile1, I2/tile2, r3, tile1, tile2]
```
 // Merge remaining dimensions
  int64_t lhs_i = -1;
  for (int64_t i = reference1->nDims() - 2; i > 0; i--) {
    auto axis_i = i - 1;
    if (lhs_i == -1) {
      lhs_i = axis_i;
    } else {
      reference1->merge(axis_i, lhs_i);
      lhs_i = axis_i;
    }
```
This tries to merge a reduction iterdomain with an iteration-type
iterdomain.

This PR ignores the reduction axes when merging all non-tile dimensions.
  • Loading branch information
Priya2698 authored and protonu committed Jun 26, 2024
1 parent 758309a commit bf26c7d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
33 changes: 20 additions & 13 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1219,23 +1219,30 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}});
// [..., I1/tile1, .., I2/tile2, ..., tile1, tile2]

// Merge remaining dimensions
int64_t lhs_i = -1;
for (int64_t i = reference1->nDims() - 2; i > 0; i--) {
auto axis_i = i - 1;
if (lhs_i == -1) {
lhs_i = axis_i;
} else {
reference1->merge(axis_i, lhs_i);
lhs_i = axis_i;
// Merge remaining dimensions ignoring reduction axes (See Issue #2317)
// The reduction axes cannot be at any position.
// For example: [i0, r1, i1, r2, i2] after tiling is [i0, r1, i1/tile1, r2,
// i2/tile2, tile1, tile2] The following code merges all the outer iterdomains
// as: [i0 * i1/tile1 * i2/tile2, r1, r2, tile1, tile2]
int64_t rhs_i = reference1->nDims() - 3;
for (int64_t lhs_i = reference1->nDims() - 4; lhs_i >= 0; lhs_i--) {
if (reference1->axis(lhs_i)->isReduction()) {
continue;
}
if (reference1->axis(rhs_i)->isReduction()) {
rhs_i = lhs_i;
continue;
}
reference1->merge(lhs_i, rhs_i);
rhs_i = lhs_i;
}
reference1->split(0, 1);
// [merged_dim, 1, tile1, tile2]

reference1->split(rhs_i, 1);
// [r.., merged_dim, 1, tile1, tile2]

// parallelize non-tile dimensions
reference1->axis(1)->parallelize(ParallelType::Unswitch);
reference1->axis(0)->parallelize(ParallelType::BIDx);
reference1->axis(rhs_i + 1)->parallelize(ParallelType::Unswitch);
reference1->axis(rhs_i)->parallelize(ParallelType::BIDx);
// [BIDx, Unswitch, tile1, tile2]

// Propagate transformations so far to the entire DAG
Expand Down
27 changes: 27 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4018,6 +4018,33 @@ def fusion_func(fd: FusionDefinition) -> None:

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

# See https://github.com/NVIDIA/Fuser/issues/2317
@unittest.skipIf(is_pre_ampere(), "Only supported on Ampere and newer devices.")
def test_reduction_transpose_sched_issue2317(self):
    # Inputs sized so the permute/reshape/linear chain exercises the
    # transpose scheduler with reduction axes present in the reference.
    tensors = [
        torch.randn((16, 25, 128, 64), dtype=torch.bfloat16, device="cuda:0"),
        torch.randn((16, 128, 1600), dtype=torch.bfloat16, device="cuda:0"),
        torch.randn((1600, 1600), dtype=torch.bfloat16, device="cuda:0"),
    ]

    def build_fusion(fd: FusionDefinition, inputs) -> None:
        # Two stacked linear+add stages fed by a permuted/reshaped input;
        # this shape of fusion previously tripped the transpose
        # scheduler's dimension merge when a reduction axis was present.
        act = fd.from_pytorch(inputs[0])
        residual = fd.from_pytorch(inputs[1])
        weight = fd.from_pytorch(inputs[2])

        permuted = fd.ops.permute(act, dims=[0, 2, 1, 3])
        contiguous = fd.ops.stride_order(permuted, stride_order=[3, 2, 1, 0])
        flattened = fd.ops.reshape(contiguous, new_shape=residual.shape())
        stage1 = fd.ops.add(fd.ops.linear(flattened, weight), residual)

        stage1_bf16 = fd.ops.cast(stage1, dtype=DataType.BFloat16)
        stage2 = fd.ops.add(fd.ops.linear(stage1_bf16, weight), stage1_bf16)
        fd.add_output(stage2)

    nvf_out, _ = self.exec_nvfuser(partial(build_fusion, inputs=tensors), tensors)

def test_fusion_profiler(self):
inputs = [
torch.randn((2, 5), dtype=torch.float, device="cuda:0"),
Expand Down

0 comments on commit bf26c7d

Please sign in to comment.