diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 8ff1eb4fea65..b611631744bb 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -641,6 +641,7 @@ "SliceModule_basic", "SliceNegIdxModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", @@ -651,6 +652,7 @@ "SliceScatterNegativeDimModule_basic", "SliceScatterNegativeEndModule_basic", "SliceScatterStaticModule_basic", + "SliceEndSleStartStaticModule_basic", "SliceScatterStepVariationModule_basic", "SliceScatterZeroDimModule_basic", "SqueezeDimModule_static", @@ -1096,8 +1098,8 @@ "BroadcastZeroRankInputStaticModule_basic", "BroadcastListConstructWithMinusOneModule_basic", "SliceStaticModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceSizeTwoStepDivisibleStaticModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", "ArangeStartStepIntModule_basic", "ArangeDtypeFloatModule_basic", "ArangeIntModule_basic", @@ -1346,6 +1348,7 @@ "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", + "SliceEndSleStartStaticModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 9789bd7cdfe8..15bbb2e42257 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3323,6 +3323,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); + auto outTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "output type must be ranked"); + } + if (outTy.hasStaticShape() && outTy.getNumElements() == 0) { + return rewriter.notifyMatchFailure(op, + "tosa.slice does not support zero size"); + } + // Only statically deducible values are currently supported int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -3333,36 +3343,34 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim must less than tensor rank"); + auto sizeOfDim = selfType.getDimSize(dim); + int64_t start; if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - if (start < 0) - return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); - - start = std::min(selfType.getShape()[dim], start); + // support for start < 0 + start = toPositiveDim(start, sizeOfDim); + start = std::clamp(start, (int64_t)0, sizeOfDim); int64_t end; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { if (isa(op.getEnd().getDefiningOp())) - end = selfType.getShape()[dim]; + end = sizeOfDim; else return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); } - // support for end < 0 - end = toPositiveDim(end, selfType.getShape()[dim]); - end = std::min(end, selfType.getDimSize(dim)); - // FIXME: add support for start < 0 and end < start - if (end < start) - return rewriter.notifyMatchFailure(op, - "Currently unsupported: end < start"); + // support for end < 0 + end = toPositiveDim(end, sizeOfDim); + end = std::min(end, sizeOfDim); + // Handle start > end + end = std::clamp(end, (int64_t)0, sizeOfDim); int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "step must be a Scalar constant"); - auto sizeOfDim = selfType.getDimSize(dim); if (sizeOfDim % step != 0) { return rewriter.notifyMatchFailure(op, "size must be divisible by step"); } @@ -3380,15 +3388,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector startSlice(reshaped.getType().getRank(), 0); - startSlice[dim+1] = start % step; - // Due to the reshaping, the dimension shifted up by one startSlice[dim] = start / step; - - auto outTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - if (!outTy) { - return rewriter.notifyMatchFailure(op, "output type must be ranked"); - } + startSlice[dim+1] = start % step; SmallVector sliceShape{outTy.getShape()}; sliceShape.insert(sliceShape.begin() + dim+1, 1); diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index e2fb3a3071b9..d09580958acc 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -133,6 +133,25 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceOutOfLowerBoundStartIndexStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 4, 7], torch.float32, True), + ]) + def forward(self, x): + return x[-8:3:1, :, :] + + +@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexStaticModule()) +def SliceOutOfLowerBoundStartIndexStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + class SliceEndSleStartModule(torch.nn.Module): def __init__(self): @@ -157,6 +176,30 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceEndSleStartStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 4, 7], torch.float32, True), + ]) + def forward(self, x): + # TODO: remove hacky cat tensor once refbackend supports 0 size dim + result = x[:, 4:3, :] + cat_tensor = torch.ones((6,1,7), dtype=torch.float32) + return torch.cat((result, cat_tensor), dim=1) + + +@register_test_case(module_factory=lambda: SliceEndSleStartStaticModule()) +def SliceEndSleStartStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + + +# ============================================================================== + + class SliceStartEqEndModule(torch.nn.Module): def __init__(self): super().__init__()