TOSA: slice: Support start < 0, end < start and start + sizeOfDim < 0 (
mgehre-amd authored Jun 9, 2023
1 parent dfc3c0d commit 3f8a1cd
Showing 3 changed files with 69 additions and 22 deletions.
e2e_testing/xfail_sets.py: 5 changes (4 additions & 1 deletion)
@@ -641,6 +641,7 @@
     "SliceModule_basic",
     "SliceNegIdxModule_basic",
     "SliceOutOfLowerBoundStartIndexModule_basic",
+    "SliceOutOfLowerBoundStartIndexStaticModule_basic",
     "SliceOutOfUpperBoundIndexModule_basic",
     "SliceOutOfUpperBoundIndexStaticModule_basic",
     "SliceStartEqEndModule_basic",
@@ -651,6 +652,7 @@
     "SliceScatterNegativeDimModule_basic",
     "SliceScatterNegativeEndModule_basic",
     "SliceScatterStaticModule_basic",
+    "SliceEndSleStartStaticModule_basic",
     "SliceScatterStepVariationModule_basic",
     "SliceScatterZeroDimModule_basic",
     "SqueezeDimModule_static",
@@ -1096,8 +1098,8 @@
     "BroadcastZeroRankInputStaticModule_basic",
     "BroadcastListConstructWithMinusOneModule_basic",
     "SliceStaticModule_basic",
-    "SliceOutOfUpperBoundIndexStaticModule_basic",
     "SliceSizeTwoStepDivisibleStaticModule_basic",
+    "SliceOutOfLowerBoundStartIndexStaticModule_basic",
     "ArangeStartStepIntModule_basic",
     "ArangeDtypeFloatModule_basic",
     "ArangeIntModule_basic",
@@ -1346,6 +1348,7 @@
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
     "SliceEndSleStartModule_basic",
+    "SliceEndSleStartStaticModule_basic",
     "SliceOutOfUpperBoundIndexModule_basic",
     "SliceOutOfUpperBoundIndexStaticModule_basic",
     "SliceStartEqEndModule_basic",
lib/Conversion/TorchToTosa/TorchToTosa.cpp: 43 changes (22 additions & 21 deletions)
@@ -3323,6 +3323,16 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
 
+  auto outTy =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  if (!outTy) {
+    return rewriter.notifyMatchFailure(op, "output type must be ranked");
+  }
+  if (outTy.hasStaticShape() && outTy.getNumElements() == 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "tosa.slice does not support zero size");
+  }
+
   // Only statically deducible values are currently supported
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -3333,36 +3343,34 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   if (!isValidDim(dim, selfType.getRank()))
     return rewriter.notifyMatchFailure(op, "dim must less than tensor rank");
 
+  auto sizeOfDim = selfType.getDimSize(dim);
+
   int64_t start;
   if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
     return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");
 
-  if (start < 0)
-    return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0");
-  start = std::min(selfType.getShape()[dim], start);
+  // support for start < 0
+  start = toPositiveDim(start, sizeOfDim);
+  start = std::clamp(start, (int64_t)0, sizeOfDim);
 
   int64_t end;
   if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
     if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp()))
-      end = selfType.getShape()[dim];
+      end = sizeOfDim;
     else
       return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
   }
-  // support for end < 0
-  end = toPositiveDim(end, selfType.getShape()[dim]);
-  end = std::min(end, selfType.getDimSize(dim));
-
-  // FIXME: add support for start < 0 and end < start
-  if (end < start)
-    return rewriter.notifyMatchFailure(op,
-                                       "Currently unsupported: end < start");
+  // support for end < 0
+  end = toPositiveDim(end, sizeOfDim);
+  end = std::min(end, sizeOfDim);
+  // Handle start > end
+  end = std::clamp(end, (int64_t)0, sizeOfDim);
 
   int64_t step;
   if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
     return rewriter.notifyMatchFailure(op, "step must be a Scalar constant");
 
-  auto sizeOfDim = selfType.getDimSize(dim);
   if (sizeOfDim % step != 0) {
     return rewriter.notifyMatchFailure(op, "size must be divisible by step");
  }
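To make the new bound handling in the hunk above concrete, here is a minimal Python sketch (not part of the commit; the function name is illustrative) of what toPositiveDim followed by std::clamp does to start and end:

def normalize_slice_bounds(start, end, size_of_dim):
    # toPositiveDim: a negative index counts from the back of the dimension.
    if start < 0:
        start += size_of_dim
    # std::clamp(start, 0, sizeOfDim): also covers start + sizeOfDim < 0.
    start = min(max(start, 0), size_of_dim)

    if end < 0:
        end += size_of_dim
    end = min(end, size_of_dim)
    # std::clamp(end, 0, sizeOfDim): covers end + sizeOfDim < 0.
    end = min(max(end, 0), size_of_dim)
    return start, end

# start = -8 on a dimension of size 6: toPositiveDim yields -2, the clamp
# yields 0, matching PyTorch's x[-8:3] == x[0:3].
assert normalize_slice_bounds(-8, 3, 6) == (0, 3)

When the normalized bounds still describe an empty slice (end <= start), the zero-element guard added at the top of the pattern rejects the op cleanly instead of emitting an invalid tosa.slice.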
@@ -3380,15 +3388,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
 
   SmallVector<int64_t> startSlice(reshaped.getType().getRank(), 0);
 
-  startSlice[dim+1] = start % step;
   // Due to the reshaping, the dimension shifted up by one
   startSlice[dim] = start / step;
-
-  auto outTy =
-      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  if (!outTy) {
-    return rewriter.notifyMatchFailure(op, "output type must be ranked");
-  }
+  startSlice[dim+1] = start % step;
 
   SmallVector<int64_t> sliceShape{outTy.getShape()};
   sliceShape.insert(sliceShape.begin() + dim+1, 1);
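The startSlice arithmetic in the last hunk relies on the lowering's existing strategy for step > 1: a dimension of size N is reshaped to (N / step, step), the slice start decomposes into a row index start / step and a column index start % step, and the result is reshaped back. A hedged PyTorch sketch of that decomposition (variable names are illustrative, not taken from the lowering):

import torch

x = torch.arange(12)                     # N = 12
step, start, out_len = 3, 4, 2           # slice x[4::3], keeping 2 elements

reshaped = x.reshape(12 // step, step)   # shape (4, 3)
# startSlice[dim] = start / step, startSlice[dim+1] = start % step
sliced = reshaped[start // step : start // step + out_len, start % step]

assert torch.equal(sliced, x[start::step][:out_len])   # tensor([4, 7])

This decomposition is also why the pattern still requires sizeOfDim % step == 0 and bails out otherwise.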
python/torch_mlir_e2e_test/test_suite/slice_like.py: 43 changes (43 additions & 0 deletions)
@@ -133,6 +133,25 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class SliceOutOfLowerBoundStartIndexStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([6, 4, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x[-8:3:1, :, :]
+
+
+@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexStaticModule())
+def SliceOutOfLowerBoundStartIndexStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4,7))
+
+# ==============================================================================
+
 
 class SliceEndSleStartModule(torch.nn.Module):
     def __init__(self):
@@ -157,6 +176,30 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class SliceEndSleStartStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([6, 4, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[:, 4:3, :]
+        cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=1)
+
+
+@register_test_case(module_factory=lambda: SliceEndSleStartStaticModule())
+def SliceEndSleStartStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4,7))
+
+
+# ==============================================================================
+
+
 class SliceStartEqEndModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
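The two new tests pin down the PyTorch semantics the lowering now has to match; a quick standalone check of both behaviors (assumes only torch):

import torch

x = torch.rand(6, 4, 7)

# Out-of-lower-bound start: PyTorch clamps -8 to 0 on a dimension of size 6.
assert torch.equal(x[-8:3:1, :, :], x[0:3:1, :, :])

# end <= start: PyTorch produces a zero-size dimension, which tosa.slice
# cannot represent; hence the zero-size guard in the lowering and the
# torch.cat workaround inside SliceEndSleStartStaticModule.
assert x[:, 4:3, :].shape == torch.Size([6, 0, 7])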
