From 6d41b9d363c8c2019ea81383e31e79936aa716f7 Mon Sep 17 00:00:00 2001
From: Han-Chung Wang
Date: Sat, 13 Mar 2021 01:41:35 +0800
Subject: [PATCH] Create InitTensorOp based on input operand shapes. (#5070)

According to the documentation, the output shape is:

  `params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]`

E.g., if the variables are:

  params.shape = [16, 2, 3]
  indices.shape = [16, 5]
  axis = 2
  batch_dims = 1

Then the output shape is `[16, 2] + [5] + [] = [16, 2, 5]`.

See https://www.tensorflow.org/api_docs/python/tf/gather for more details.
---
 .../HLOToLinalg/HLOToLinalgOnTensors.cpp      | 64 +++++++++----------
 .../HLOToLinalg/test/torch_index_select.mlir  | 64 ++++++++-----------
 2 files changed, 58 insertions(+), 70 deletions(-)

diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 2c5b278037f9..b81680116bc5 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -52,29 +52,6 @@ namespace {
 // mhlo.torch_index_select conversion patterns.
 //===----------------------------------------------------------------------===//
 
-static Value getOutputTensor(OpBuilder &builder, Location loc, Value opResult) {
-  ShapedType outputType = opResult.getType().cast<ShapedType>();
-  if (outputType.hasStaticShape()) {
-    return builder.create<linalg::InitTensorOp>(loc, outputType.getShape(),
-                                                outputType.getElementType());
-  }
-  // Check for tie-shape operations for the result to get the shape of the
-  // output.
-  SmallVector<Value, 4> dynamicSizes;
-  for (Operation *user : opResult.getUsers()) {
-    auto tieShapeOp = dyn_cast<Shape::TieShapeOp>(user);
-    if (!tieShapeOp) continue;
-    auto makeShapeOp =
-        tieShapeOp.shape().getDefiningOp<Shape::MakeRankedShapeOp>();
-    if (!makeShapeOp) continue;
-    dynamicSizes = llvm::to_vector<4>(makeShapeOp.dynamic_dimensions());
-    break;
-  }
-  if (outputType.getNumDynamicDims() != dynamicSizes.size()) return nullptr;
-  return builder.create<linalg::InitTensorOp>(
-      loc, dynamicSizes, outputType.getShape(), outputType.getElementType());
-}
-
 namespace {
 
 /// Converts xla-hlo.torch_index_select op to a linalg.indexed_generic op.
@@ -88,30 +65,49 @@ struct TorchIndexSelectOpConversion
     mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
     int axis = op.dim();
     int batch = op.batch_dims();
-    auto indexShapeType = adaptor.index().getType().dyn_cast<ShapedType>();
+    auto indexShapeType = adaptor.index().getType().cast<ShapedType>();
     int nIndices = indexShapeType.getRank();
-    auto inputShapeType = adaptor.input().getType().dyn_cast<ShapedType>();
+    auto inputShapeType = adaptor.input().getType().cast<ShapedType>();
     if (axis < 0) axis += inputShapeType.getRank();
     if (batch < 0) batch += nIndices;
 
     Location loc = op.getLoc();
-    Value output = op.getResult();
-    int rank = output.getType().cast<ShapedType>().getRank();
+    ShapedType resultType = op.getResult().getType().cast<ShapedType>();
+    int rank = resultType.getRank();
+
     SmallVector<AffineMap, 2> indexingMaps;
     SmallVector<AffineExpr, 4> exprs;
-    for (int i = 0; i < batch; ++i)
+    for (int i = 0; i < batch; ++i) {
       exprs.push_back(rewriter.getAffineDimExpr(i));
-    for (int i = 0, e = nIndices - batch; i < e; ++i)
+    }
+    for (int i = 0, e = nIndices - batch; i < e; ++i) {
       exprs.push_back(rewriter.getAffineDimExpr(axis + i));
+    }
     indexingMaps.emplace_back(
         AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
     indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
+
     SmallVector<StringRef, 3> loopTypes(rank, getParallelIteratorTypeName());
-    ShapedType outputType = op.getResult().getType().cast<ShapedType>();
-    Value initOp = getOutputTensor(rewriter, loc, op.getResult());
-    if (!initOp) return failure();
+
+    // The output shape is
+    // `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
+    SmallVector<Value, 4> dynSizes;
+    for (int i = 0; i < rank; ++i) {
+      if (!resultType.isDynamicDim(i)) continue;
+      if (i < axis) {
+        dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.input(), i));
+      } else if (i < (axis + nIndices - batch)) {
+        int idx = i - axis + batch;
+        dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.index(), idx));
+      } else {
+        int idx = i - (axis + nIndices - batch) + axis + 1;
+        dynSizes.push_back(rewriter.create<DimOp>(loc, adaptor.input(), idx));
+      }
+    }
+    Value initOp = rewriter.create<linalg::InitTensorOp>(
+        loc, dynSizes, resultType.getShape(), resultType.getElementType());
     auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
-        loc, /*resultTensors=*/ArrayRef<Type>{op.getResult().getType()},
+        loc, /*resultTensors=*/ArrayRef<Type>{resultType},
         /*inputs=*/adaptor.index(),
         /*outputBuffers=*/initOp, indexingMaps, loopTypes);
 
@@ -126,7 +122,7 @@ struct TorchIndexSelectOpConversion
           blockArgs.getType().cast<ShapedType>().getElementType());
     }
     block->addArguments(bodyArgTypes);
-    block->addArguments(outputType.getElementType());
+    block->addArguments(resultType.getElementType());
 
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToEnd(block);
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/torch_index_select.mlir b/iree/compiler/Conversion/HLOToLinalg/test/torch_index_select.mlir
index 38e1b68d90af..182780fd6966 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/torch_index_select.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/torch_index_select.mlir
@@ -110,48 +110,40 @@ module {
 
 // -----
 
-func @torch_index_select_dynamic() {
-  %c0 = constant 0 : index
-  %0 = hal.interface.load.constant offset = 9 : index
-  %1 = hal.interface.load.constant offset = 8 : index
-  %2 = shapex.make_ranked_shape %1, %0 : (index, index) -> !shapex.ranked_shape<[?,?]>
-  %3 = hal.interface.load.constant offset = 7 : index
-  %4 = hal.interface.load.constant offset = 6 : index
-  %5 = hal.interface.load.constant offset = 5 : index
-  %6 = shapex.make_ranked_shape %1, %5, %4, %3 : (index, index, index, index) -> !shapex.ranked_shape<[?,?,?,?]>
-  %7 = hal.interface.load.constant offset = 4 : index
-  %8 = shapex.make_ranked_shape %7, %5, %4, %3 : (index, index, index, index) -> !shapex.ranked_shape<[?,?,?,?]>
-  %9 = hal.interface.load.tensor @legacy_io::@arg4, offset = %c0 {operand_result_index = 0 : i32} : tensor<?x?x?x?xi32>
-  %10 = hal.interface.load.tensor @legacy_io::@arg9, offset = %c0 {operand_result_index = 1 : i32} : tensor<?x?xi32>
-  %11 = shapex.tie_shape %10, %2 : tensor<?x?xi32>, !shapex.ranked_shape<[?,?]>
-  %12 = shapex.tie_shape %9, %8 : tensor<?x?x?x?xi32>, !shapex.ranked_shape<[?,?,?,?]>
-  %13 = "mhlo.torch_index_select"(%11, %12) {batch_dims = 1 : i64, dim = 1 : i64} : (tensor<?x?xi32>, tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
-  %14 = shapex.tie_shape %13, %6 : tensor<?x?x?x?xi32>, !shapex.ranked_shape<[?,?,?,?]>
-  hal.interface.store.tensor %14, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<?x?x?x?xi32>
-  return
+func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
+                                 %index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32>{
+  %0 = "mhlo.torch_index_select"(%input, %index) {
+    batch_dims = 1 : i64,
+    dim = 2 : i64
+  } : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @torch_index_select_dynamic
-// CHECK-DAG: %[[D0:.+]] = hal.interface.load.constant offset = 8
-// CHECK-DAG: %[[D1:.+]] = hal.interface.load.constant offset = 5
-// CHECK-DAG: %[[D2:.+]] = hal.interface.load.constant offset = 6
-// CHECK-DAG: %[[D3:.+]] = hal.interface.load.constant offset = 7
-// CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[D0]], %[[D1]], %[[D2]], %[[D3]]
-// CHECK: %[[INPUT:.+]] = hal.interface.load.tensor @legacy_io::@arg9
-// CHECK: %[[INPUT_TIED:.+]] = shapex.tie_shape %[[INPUT]], %{{.+}}
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[D0:.+]] = dim %[[INPUT]], %[[C0]]
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[D1:.+]] = dim %[[INPUT]], %[[C1]]
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[D2:.+]] = dim %[[INDEX]], %[[C1]]
+// CHECK: %[[C3:.+]] = constant 3 : index
+// CHECK: %[[D3:.+]] = dim %[[INPUT]], %[[C3]]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
 // CHECK: %[[RESULT:.+]] = linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%{{[a-zA-Z0-9_]+}} : tensor<?x?x?x?xi32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
 // CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: index, %{{[a-zA-Z0-9_]+}}: index,
-// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: index, %[[ARG4:[a-zA-Z0-9_]+]]: i32
-// CHECK-SAME: %{{[a-zA-Z0-9_]+}}: i32)
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
 // CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
-// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT_TIED]][%[[ARG0]], %[[POS]]]
+// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
 // CHECK: linalg.yield %[[YIELD]]
-// CHECK: shapex.tie_shape %[[RESULT]], %[[SHAPE]]
\ No newline at end of file
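
For reference, and not part of the patch itself: the dynSizes loop added above follows the
output-shape rule quoted in the commit message. A minimal standalone C++ sketch of that rule,
shapes only, with a hypothetical helper name:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // output = params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]
  std::vector<int64_t> torchIndexSelectOutputShape(const std::vector<int64_t> &params,
                                                   const std::vector<int64_t> &indices,
                                                   int axis, int batchDims) {
    // Leading dims come straight from params.
    std::vector<int64_t> out(params.begin(), params.begin() + axis);
    // The selected axis is replaced by the non-batch dims of indices.
    out.insert(out.end(), indices.begin() + batchDims, indices.end());
    // Trailing dims come from params after the selected axis.
    out.insert(out.end(), params.begin() + axis + 1, params.end());
    return out;
  }

  int main() {
    // Example from the commit message: [16, 2] + [5] + [] = [16, 2, 5].
    assert((torchIndexSelectOutputShape({16, 2, 3}, {16, 5}, 2, 1) ==
            std::vector<int64_t>{16, 2, 5}));
    return 0;
  }

Result dimensions below `axis` are sized from the input operand, the next `nIndices - batch`
dimensions from the index operand, and the remaining ones from the input operand past `axis`,
which is how the new loop chooses the operand for each DimOp when building the dynamic sizes
of the linalg.init_tensor.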