Fixes for split tensor and slice (llvm#2314)
* RecomposeComplexOps: Remove dead slice op

* lib/Dialect/Torch/IR/TorchOps.cpp: Fold slice ops even when they are on non-value tensors

* lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix slice start/end that are out of range or none

* lib/Dialect/Torch/IR/TorchOps.cpp: AtenSliceTensorOp::fold: Fold slices that go from 0:int_max

* More tests for aten.split.Tensor
mgehre-amd authored and jinchen62 committed Jul 20, 2023
1 parent d5c1308 commit 852b79b
Showing 7 changed files with 125 additions and 8 deletions.
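
For context, a minimal eager-mode PyTorch sketch (an editorial illustration, not part of this commit) of the slice semantics the new fold and the TOSA fix target:

import torch

x = torch.rand(6, 4, 7)

# A slice with start=0, end=INT64_MAX (9223372036854775807) and step=1 covers the
# whole dimension; this is the case the new AtenSliceTensorOp::fold collapses to
# its input tensor.
assert torch.equal(x[0:9223372036854775807], x)

# An end of None likewise means "through the end of the dimension"; the TOSA
# lowering now treats it as the dimension size.
assert torch.equal(x[2:None], x[2:])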
10 changes: 10 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -631,6 +631,7 @@
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
@@ -797,6 +798,8 @@
"AtenComplex64Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"ChunkListUnpack_Module_basic",
@@ -1051,6 +1054,7 @@
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"SliceStaticModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeIntModule_basic",
@@ -1117,13 +1121,16 @@
"ElementwiseSqrtModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"ChunkListUnpack_Module_basic",
"ChunkListUnpackUneven_Module_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
### Tests additionally passing in make_fx_tosa
"NativeGroupNormBackwardModule_basic",
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
}) - {
@@ -1261,6 +1268,7 @@
"ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStartEqEndModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
@@ -1344,6 +1352,8 @@
"AtenComplexViewModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"ChunkListUnpack_Module_basic",
10 changes: 8 additions & 2 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3194,9 +3194,15 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
  if (start < 0)
    return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0");

  start = std::min(selfType.getShape()[dim], start);

  int64_t end;
  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
    return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
    if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp()))
      end = selfType.getShape()[dim];
    else
      return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
  }
  // support for end < 0
  end = toPositiveDim(end, selfType.getShape()[dim]);
  // support for end out of upper bound
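
A rough Python model of the bound handling this lowering now performs for aten.slice.Tensor on a statically shaped input (illustrative only; the helper name is hypothetical and not from the commit):

def clamp_slice_bounds(dim_size, start, end):
    # start is clamped into range (negative starts are still rejected earlier)
    start = min(dim_size, start)
    # an end of None means "slice to the end of the dimension"
    if end is None:
        end = dim_size
    # negative ends count from the back; ends are clamped to the dimension size
    if end < 0:
        end += dim_size
    end = min(dim_size, end)
    return start, end

assert clamp_slice_bounds(7, 8, None) == (7, 7)
assert clamp_slice_bounds(7, 0, -1) == (0, 6)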
13 changes: 11 additions & 2 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2296,8 +2296,17 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
  auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
  auto outType = getResult().getType().dyn_cast<ValueTensorType>();
  int64_t start, end, step;
  if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
      matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
      matchPattern(getStep(), m_TorchConstantInt(&step))
      && step == 1
      && start == 0
      && end == std::numeric_limits<int64_t>::max())
    return getOperand(0);

  auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  auto outType = getResult().getType().dyn_cast<BaseTensorType>();
  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
    return nullptr;
  if (inType.getSizes().size() != outType.getSizes().size() ||
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -103,6 +103,9 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
        op, sliceOpInput.getType(), sliceOpInput, indices, op.getSrc(),
        /*accumulate=*/falseVal, /*unsafe=*/falseVal);

    if (sliceOp->use_empty())
      rewriter.eraseOp(sliceOp);

    return success();
  }
};
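
The slice + copy_ pair this pattern recomposes typically comes from an in-place write through a basic slice; a rough eager-mode sketch of such a program (illustrative, not taken from the commit):

import torch

x = torch.zeros(8)
x[2:5] = torch.ones(3)  # captured as aten.slice.Tensor feeding aten.copy_
# Recomposition rewrites the copy_ into an index_put; the slice that fed it is
# then dead, and the added check erases it once it has no remaining uses.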
68 changes: 68 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -73,6 +73,28 @@ def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):

# ==============================================================================

class SliceOutOfUpperBoundIndexStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([6, 4, 7], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
        result = x[:8, :5, 8:]
        cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
        return torch.cat((result,cat_tensor), dim=2)


@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexStaticModule())
def SliceOutOfUpperBoundIndexStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6,4,7))

# ==============================================================================

class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -696,6 +718,52 @@ def SplitTensorListUnpackModule_basic(module, tu: TestUtils):

# ==============================================================================


class SplitTensorLastSmallerModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([8, 10, 12], torch.float32, True)
    ])
    def forward(self, x):
        s0, s1, s2 = torch.ops.aten.split(x, 3, dim=0)
        return s2


@register_test_case(module_factory=lambda: SplitTensorLastSmallerModule())
def SplitTensorLastSmallerModule_basic(module, tu: TestUtils):
    # Splitting the first dimension with 8 elements into chunks of 3
    # will leave the last result to have 2 elements in that dimension.
    module.forward(tu.rand(8, 10, 12))

# ==============================================================================


class SplitTensorNegativeDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([10, 12, 6], torch.float32, True)
    ])
    def forward(self, x):
        s0, s1, s2 = torch.ops.aten.split(x, 2, -1)
        return s1


@register_test_case(module_factory=lambda: SplitTensorNegativeDimModule())
def SplitTensorNegativeDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 12, 6))

# ==============================================================================

class ChunkListUnpack_Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
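
For quick reference, the behaviors the new tests exercise, shown in eager PyTorch (an editorial illustration, not part of the test file):

import torch

x = torch.rand(6, 4, 7)
print(x[:8, :5, 8:].shape)           # torch.Size([6, 4, 0]); out-of-range bounds are clamped

parts = torch.ops.aten.split(torch.rand(8, 10, 12), 3, dim=0)
print([p.shape[0] for p in parts])   # [3, 3, 2]; the last chunk is smaller

parts = torch.ops.aten.split(torch.rand(10, 12, 6), 2, -1)
print([p.shape[-1] for p in parts])  # [2, 2, 2]; a negative dim counts from the end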
8 changes: 4 additions & 4 deletions test/Conversion/TorchToStablehlo/view_like.mlir
@@ -7,8 +7,8 @@
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]]
// CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -48,8 +48,8 @@
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
  %int10 = torch.constant.int 10
  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int10, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}

21 changes: 21 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1973,6 +1973,27 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
  return %0 : !torch.vtensor<[4],f32>
}

// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[?],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[?],f32>
func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
  %int1 = torch.constant.int 1
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  %int0 = torch.constant.int 0
  %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32>
  return %0 : !torch.vtensor<[?],f32>
}

// CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step
// CHECK: torch.aten.slice.Tensor
func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
  %int2 = torch.constant.int 2
  %int9223372036854775807 = torch.constant.int 9223372036854775807
  %int0 = torch.constant.int 0
  %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32>
  return %0 : !torch.vtensor<[?],f32>
}

// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
