Create InitTensorOp based on input operand shapes. (#5070)
According to the documentation, the output shape is:
`params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]`

E.g., if the variables are:
params.shape = [16, 2, 3]
indices.shape = [16, 5]
axis = 2
batch_dims = 1

Then the output shape is `[16, 2] + [5] + [] = [16, 2, 5]`.

See https://www.tensorflow.org/api_docs/python/tf/gather for more details.
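
For illustration only (not part of the commit), a minimal C++ sketch of the shape rule above; the helper name gatherResultShape is hypothetical:

#include <cstdint>
#include <vector>

// Hypothetical helper, shown only to illustrate the documented shape rule:
//   params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]
static std::vector<int64_t> gatherResultShape(const std::vector<int64_t> &params,
                                              const std::vector<int64_t> &indices,
                                              int axis, int batchDims) {
  // Leading params dims up to (but excluding) `axis`.
  std::vector<int64_t> result(params.begin(), params.begin() + axis);
  // Index dims after the shared batch dims.
  result.insert(result.end(), indices.begin() + batchDims, indices.end());
  // Trailing params dims after `axis`.
  result.insert(result.end(), params.begin() + axis + 1, params.end());
  return result;
}

// With the example above: params = {16, 2, 3}, indices = {16, 5}, axis = 2,
// batch_dims = 1, the helper returns {16, 2, 5}.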
hanhanW authored Mar 12, 2021
1 parent a2a464d commit 6d41b9d
Showing 2 changed files with 58 additions and 70 deletions.
64 changes: 30 additions & 34 deletions iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -52,29 +52,6 @@ namespace {
// mhlo.torch_index_select conversion patterns.
//===----------------------------------------------------------------------===//

static Value getOutputTensor(OpBuilder &builder, Location loc, Value opResult) {
ShapedType outputType = opResult.getType().cast<ShapedType>();
if (outputType.hasStaticShape()) {
return builder.create<linalg::InitTensorOp>(loc, outputType.getShape(),
outputType.getElementType());
}
// Check for tie-shape operations for the result to get the shape of the
// output.
SmallVector<Value, 4> dynamicSizes;
for (Operation *user : opResult.getUsers()) {
auto tieShapeOp = dyn_cast<Shape::TieShapeOp>(user);
if (!tieShapeOp) continue;
auto makeShapeOp =
tieShapeOp.shape().getDefiningOp<Shape::MakeRankedShapeOp>();
if (!makeShapeOp) continue;
dynamicSizes = llvm::to_vector<4>(makeShapeOp.dynamic_dimensions());
break;
}
if (outputType.getNumDynamicDims() != dynamicSizes.size()) return nullptr;
return builder.create<linalg::InitTensorOp>(
loc, dynamicSizes, outputType.getShape(), outputType.getElementType());
}

namespace {

/// Converts xla-hlo.torch_index_select op to a linalg.indexed_generic op.
@@ -88,30 +65,49 @@ struct TorchIndexSelectOpConversion
mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
int axis = op.dim();
int batch = op.batch_dims();
auto indexShapeType = adaptor.index().getType().dyn_cast<ShapedType>();
auto indexShapeType = adaptor.index().getType().cast<ShapedType>();
int nIndices = indexShapeType.getRank();
auto inputShapeType = adaptor.input().getType().dyn_cast<ShapedType>();
auto inputShapeType = adaptor.input().getType().cast<ShapedType>();
if (axis < 0) axis += inputShapeType.getRank();
if (batch < 0) batch += nIndices;

Location loc = op.getLoc();
Value output = op.getResult();
int rank = output.getType().cast<ShapedType>().getRank();
ShapedType resultType = op.getResult().getType().cast<ShapedType>();
int rank = resultType.getRank();

SmallVector<AffineMap, 2> indexingMaps;
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < batch; ++i)
for (int i = 0; i < batch; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(i));
for (int i = 0, e = nIndices - batch; i < e; ++i)
}
for (int i = 0, e = nIndices - batch; i < e; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(axis + i));
}
indexingMaps.emplace_back(
AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

SmallVector<StringRef, 3> loopTypes(rank, getParallelIteratorTypeName());
ShapedType outputType = op.getResult().getType().cast<ShapedType>();
Value initOp = getOutputTensor(rewriter, loc, op.getResult());
if (!initOp) return failure();

// The output shape is
// `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
SmallVector<Value, 4> dynSizes;
for (int i = 0; i < rank; ++i) {
if (!resultType.isDynamicDim(i)) continue;
if (i < axis) {
dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.input(), i));
} else if (i < (axis + nIndices - batch)) {
int idx = i - axis + batch;
dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.index(), idx));
} else {
int idx = i - (axis + nIndices - batch) + axis + 1;
dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.input(), idx));
}
}
Value initOp = rewriter.create<linalg::InitTensorOp>(
loc, dynSizes, resultType.getShape(), resultType.getElementType());
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
loc, /*resultTensors=*/ArrayRef<Type>{op.getResult().getType()},
loc, /*resultTensors=*/ArrayRef<Type>{resultType},
/*inputs=*/adaptor.index(),
/*outputBuffers=*/initOp, indexingMaps, loopTypes);

@@ -126,7 +122,7 @@ struct TorchIndexSelectOpConversion
blockArgs.getType().cast<ShapedType>().getElementType());
}
block->addArguments(bodyArgTypes);
block->addArguments(outputType.getElementType());
block->addArguments(resultType.getElementType());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);

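As an informal sanity check (not part of the commit), the dynSizes index arithmetic added above can be traced for the configuration used by the dynamic test case in the next file (input rank 4, index rank 2, dim = 2, batch_dims = 1):

#include <cassert>

// Informal sanity check (not part of the commit): reproduces the dynSizes
// index arithmetic for the dynamic test case below, where the input has
// rank 4, the index has rank 2, dim (axis) = 2 and batch_dims = 1.
int main() {
  const int axis = 2, batch = 1, nIndices = 2, rank = 4;
  int sourceDim[4];        // which dimension each result dim is read from
  bool fromIndex[4] = {};  // true if it is read from the index operand
  for (int i = 0; i < rank; ++i) {
    if (i < axis) {
      sourceDim[i] = i;                                         // input dim i
    } else if (i < axis + nIndices - batch) {
      sourceDim[i] = i - axis + batch;                          // index dim
      fromIndex[i] = true;
    } else {
      sourceDim[i] = i - (axis + nIndices - batch) + axis + 1;  // input dim
    }
  }
  // Matches the CHECK lines in the updated test: input dims 0, 1, 3 and
  // index dim 1 feed the linalg.init_tensor sizes.
  assert(!fromIndex[0] && sourceDim[0] == 0);
  assert(!fromIndex[1] && sourceDim[1] == 1);
  assert(fromIndex[2] && sourceDim[2] == 1);
  assert(!fromIndex[3] && sourceDim[3] == 3);
  return 0;
}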
64 changes: 28 additions & 36 deletions iree/compiler/Conversion/HLOToLinalg/test/torch_index_select.mlir
@@ -110,48 +110,40 @@ module {

// -----

func @torch_index_select_dynamic() {
%c0 = constant 0 : index
%0 = hal.interface.load.constant offset = 9 : index
%1 = hal.interface.load.constant offset = 8 : index
%2 = shapex.make_ranked_shape %1, %0 : (index, index) -> !shapex.ranked_shape<[?,?]>
%3 = hal.interface.load.constant offset = 7 : index
%4 = hal.interface.load.constant offset = 6 : index
%5 = hal.interface.load.constant offset = 5 : index
%6 = shapex.make_ranked_shape %1, %5, %4, %3 : (index, index, index, index) -> !shapex.ranked_shape<[?,?,?,?]>
%7 = hal.interface.load.constant offset = 4 : index
%8 = shapex.make_ranked_shape %7, %5, %4, %3 : (index, index, index, index) -> !shapex.ranked_shape<[?,?,?,?]>
%9 = hal.interface.load.tensor @legacy_io::@arg4, offset = %c0 {operand_result_index = 0 : i32} : tensor<?x?x?x?xi32>
%10 = hal.interface.load.tensor @legacy_io::@arg9, offset = %c0 {operand_result_index = 1 : i32} : tensor<?x?xi32>
%11 = shapex.tie_shape %10, %2 : tensor<?x?xi32>, !shapex.ranked_shape<[?,?]>
%12 = shapex.tie_shape %9, %8 : tensor<?x?x?x?xi32>, !shapex.ranked_shape<[?,?,?,?]>
%13 = "mhlo.torch_index_select"(%11, %12) {batch_dims = 1 : i64, dim = 1 : i64} : (tensor<?x?xi32>, tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
%14 = shapex.tie_shape %13, %6 : tensor<?x?x?x?xi32>, !shapex.ranked_shape<[?,?,?,?]>
hal.interface.store.tensor %14, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<?x?x?x?xi32>
return
func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
%index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32>{
%0 = "mhlo.torch_index_select"(%input, %index) {
batch_dims = 1 : i64,
dim = 2 : i64
} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @torch_index_select_dynamic
// CHECK-DAG: %[[D0:.+]] = hal.interface.load.constant offset = 8
// CHECK-DAG: %[[D1:.+]] = hal.interface.load.constant offset = 5
// CHECK-DAG: %[[D2:.+]] = hal.interface.load.constant offset = 6
// CHECK-DAG: %[[D3:.+]] = hal.interface.load.constant offset = 7
// CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[D0]], %[[D1]], %[[D2]], %[[D3]]
// CHECK: %[[INPUT:.+]] = hal.interface.load.tensor @legacy_io::@arg9
// CHECK: %[[INPUT_TIED:.+]] = shapex.tie_shape %[[INPUT]], %{{.+}}
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[D0:.+]] = dim %[[INPUT]], %[[C0]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D1:.+]] = dim %[[INPUT]], %[[C1]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D2:.+]] = dim %[[INDEX]], %[[C1]]
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[D3:.+]] = dim %[[INPUT]], %[[C3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%{{[a-zA-Z0-9_]+}} : tensor<?x?x?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xi32>)
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: index, %{{[a-zA-Z0-9_]+}}: index,
// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: index, %[[ARG4:[a-zA-Z0-9_]+]]: i32
// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: i32)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT_TIED]][%[[ARG0]], %[[POS]]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
// CHECK: linalg.yield %[[YIELD]]
// CHECK: shapex.tie_shape %[[RESULT]], %[[SHAPE]]
