Skip to content

Commit

Permalink
Add support for index.Tensor on dimensions other than the first
Browse files Browse the repository at this point in the history
This patch still only supports a single indexing tensor.
  • Loading branch information
qedawkins authored and vivekkhandelwal1 committed Jul 19, 2022
1 parent 7f08169 commit c73a39e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 15 deletions.
48 changes: 33 additions & 15 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,28 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
return rewriter.notifyMatchFailure(
op, "unimplemented: the indices list is not from a list construct");
}
if (indicesTuple.size() != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor is supported");
}

SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
Value indexTensor = indicesVal[0];
if (failed(checkNotNone(rewriter, op, indexTensor))) {

int indexTensorDim = -1;
for (auto i : llvm::seq(0, (int)indicesVal.size())) {
Value index = indicesVal[i];
if (!index || failed(checkNotNone(rewriter, op, index)))
continue;
if (indexTensorDim >= 0) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only one index tensor allowed");
}
indexTensorDim = i;
}

if (indexTensorDim == -1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor must not be None");
}

Value indexTensor = indicesVal[indexTensorDim];
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType indexTensorType =
indexTensor.getType().cast<RankedTensorType>();
Expand All @@ -286,13 +295,16 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
int indexTensorRank = indexTensorType.getRank();

// This result shape calculation assumes that there is only one
// index tensor and that it is indexing the first dimension of the
// input tensor. The calculation for arbitrary inputs is much more complex.
// index tensor of the input tensor. The calculation for arbitrary inputs is
// much more complex.
SmallVector<Value> resultShape;
for (auto i : llvm::seq(0, indexTensorDim)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
for (auto i : llvm::seq(0, indexTensorRank)) {
resultShape.push_back(getDimOp(rewriter, loc, indexTensor, i));
}
for (auto i : llvm::seq(1, inputRank)) {
for (auto i : llvm::seq(indexTensorDim + 1, inputRank)) {
resultShape.push_back(getDimOp(rewriter, loc, input, i));
}
int resultRank = resultShape.size();
Expand All @@ -302,7 +314,7 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
SmallVector<AffineExpr> indicesExpr, resultExpr;
SmallVector<StringRef> iteratorTypes;

for (auto i : llvm::seq(0, indexTensorRank))
for (auto i : llvm::seq(indexTensorDim, indexTensorDim + indexTensorRank))
indicesExpr.push_back(rewriter.getAffineDimExpr(i));
for (auto i : llvm::seq(0, resultRank)) {
resultExpr.push_back(rewriter.getAffineDimExpr(i));
Expand All @@ -316,11 +328,17 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
loc, initTensor.getType(), indexTensor, initTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> extractionIndices{
castIntToIndex(b, loc, args[0])};
for (auto i : llvm::seq(1, inputRank)) {
extractionIndices.push_back(b.create<linalg::IndexOp>(
loc, i + indexTensorRank - 1));
Value index = castIntToIndex(b, loc, args[0]);
SmallVector<Value> extractionIndices;
int extra_dims = 0;
for (auto i : llvm::seq(0, inputRank)) {
if (i == indexTensorDim) {
extractionIndices.push_back(index);
extra_dims += indexTensorRank - 1;
} else {
extractionIndices.push_back(
b.create<linalg::IndexOp>(loc, i + extra_dims));
}
}
Value extractedElement = b.create<tensor::ExtractOp>(
loc, input, extractionIndices);
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,28 @@ def IndexTensorModule3dInput_basic(module, tu: TestUtils):
# ==============================================================================


class IndexTensorSelectDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1], torch.int64, True),
])
def forward(self, a, ind):
return torch.ops.aten.index(a, (None, ind, None))


@register_test_case(module_factory=lambda: IndexTensorSelectDimModule())
def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6), torch.randint(3, (2, 3)))

# ==============================================================================


class SquareModule(torch.nn.Module):

def __init__(self):
Expand Down

0 comments on commit c73a39e

Please sign in to comment.