diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
index c285b30424..73339670ff 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
@@ -40,6 +40,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
     assert(succeeded(shapecomputed) && "Could not compute output shape");
 
     auto axis = concatOp.axis();
+    assert(axis >= 0 && "negative axis should have been normalized");
     unsigned int inputNum = operands.size();
 
     // Convert the output type to MemRefType.
@@ -57,8 +58,17 @@ struct ONNXConcatOpLowering : public ConversionPattern {
     MultiDialectBuilder<KrnlBuilder, MemRefBuilder> create(rewriter, loc);
 
     // Creates loops, one for each input.
+    // Since all inputs have the same size in every dimension except 'axis',
+    // we make the loop upper bounds the same across inputs to enable further
+    // optimization. Differences may come from constant vs. dynamic sizes, or
+    // from the dynamic dims of different inputs.
     KrnlBuilder createKrnl(rewriter, loc);
+    SmallVector<IndexExpr, 4> commonUB(shapeHelper.dimsForOutput());
+    IndexExpr accumulatedOffset = LiteralIndexExpr(0);
     for (unsigned int i = 0; i < inputNum; ++i) {
+      // Since accumulatedOffsetValue will be used in a nested IndexExprScope,
+      // we get the Value of this IndexExpr and pass it in as a symbol.
+      Value accumulatedOffsetValue = accumulatedOffset.getValue();
       OpBuilder::InsertionGuard insertGuard(rewriter);
       // Create loop.
       ValueRange loopDef = createKrnl.defineLoops(rank);
@@ -66,7 +76,9 @@ struct ONNXConcatOpLowering : public ConversionPattern {
       MemRefBoundsIndexCapture bounds(operands[i]);
       SmallVector<IndexExpr, 4> ubs;
       bounds.getDimList(ubs);
-      createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
+      // For each input, only the bound of dimension 'axis' differs.
+      commonUB[axis] = ubs[axis];
+      createKrnl.iterateIE(loopDef, loopDef, lbs, commonUB,
           [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
             // Indices for the read and write.
             SmallVector<Value, 4> readIndices, writeIndices;
@@ -76,10 +88,9 @@ struct ONNXConcatOpLowering : public ConversionPattern {
               else {
                 IndexExprScope IEScope(&rewriter, loc);
                 IndexExpr writeOffset = DimIndexExpr(loopInd[r]);
-                for (unsigned int j = 0; j < i; j++) {
-                  MemRefBoundsIndexCapture operandJBounds(operands[j]);
-                  writeOffset = writeOffset + operandJBounds.getDim(r);
-                }
+                IndexExpr accumulatedOffsetIE =
+                    SymbolIndexExpr(accumulatedOffsetValue);
+                writeOffset = writeOffset + accumulatedOffsetIE;
                 writeIndices.emplace_back(writeOffset.getValue());
               }
             }
@@ -87,6 +98,8 @@ struct ONNXConcatOpLowering : public ConversionPattern {
             Value loadData = createKrnl.load(operands[i], loopInd);
             createKrnl.store(loadData, alloc, writeIndices);
           });
+      MemRefBoundsIndexCapture operandBounds(operands[i]);
+      accumulatedOffset = accumulatedOffset + operandBounds.getDim(axis);
     }
    rewriter.replaceOp(op, alloc);
    return success();
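The hunks above replace a per-input rebuild of the write offset (the deleted inner loop over operands[0..i)) with one running accumulator that is bumped once per input and re-imported into each nested scope as a symbol. Below is a minimal standalone C++ sketch of that bookkeeping, reduced to rank 1 with plain integers in place of IndexExpr/Value; the function name `concatRank1` and the std::vector-based signature are illustrative only, not part of the patch:

```cpp
#include <cstddef>
#include <vector>

// Sketch of the accumulated-offset scheme from the lowering above, at rank 1.
// Each element of input i lands at (local index + sum of the axis sizes of
// inputs 0..i-1); that sum is carried in a single running accumulator instead
// of being recomputed from all preceding operands for every input.
std::vector<float> concatRank1(const std::vector<std::vector<float>> &inputs) {
  std::size_t totalSize = 0;
  for (const std::vector<float> &input : inputs)
    totalSize += input.size(); // plays the role of cumulativeAxisSize
  std::vector<float> output(totalSize);

  std::size_t accumulatedOffset = 0; // mirrors accumulatedOffset in the patch
  for (const std::vector<float> &input : inputs) {
    for (std::size_t i = 0; i < input.size(); ++i)
      output[accumulatedOffset + i] = input[i]; // writeOffset = index + offset
    accumulatedOffset += input.size(); // mirrors adding getDim(axis) once
  }
  return output;
}
```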
Otherwise, the dimension of the first + // input tensor (implementation dependent) is used for the output tensor. + DimsExpr outputDims(commonRank); + MemRefBoundsIndexCapture firstInputBounds(operandAdaptor.inputs()[0]); + for (unsigned dim = 0; dim < commonRank; dim++) { + outputDims[dim] = firstInputBounds.getDim(dim); + } + IndexExpr cumulativeAxisSize = + DimIndexExpr(firstInputBounds.getDim(axisIndex)); + + // Handle the rest of input + for (unsigned i = 1; i < numInputs; ++i) { Value currentInput = operandAdaptor.inputs()[i]; MemRefBoundsIndexCapture currInputBounds(currentInput); - DimIndexExpr currentSize(currInputBounds.getDim(axisIndex)); - cumulativeAxisSize = cumulativeAxisSize + currentSize; + for (unsigned dim = 0; dim < commonRank; dim++) { + if (dim == axisIndex) { + DimIndexExpr currentSize(currInputBounds.getDim(axisIndex)); + cumulativeAxisSize = cumulativeAxisSize + currentSize; + } else { + if (currInputBounds.getDim(dim).isLiteral()) { + // The size of current dimension of current input is a constant + outputDims[dim] = currInputBounds.getDim(dim); + } + } + } } - - DimsExpr outputDims(commonRank); - MemRefBoundsIndexCapture firstInputBounds(firstInput); - for (unsigned i = 0; i < commonRank; i++) - outputDims[i] = - (i == axisIndex) ? cumulativeAxisSize : firstInputBounds.getDim(i); + outputDims[axisIndex] = cumulativeAxisSize; setOutputDims(outputDims); return success(); diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index fb92e385b8..1a8377990a 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -3180,3 +3180,201 @@ func.func @test_isnan(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi1> { // CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1, [[IV]]#2] : memref<2x3x4xi1> // CHECK: return [[ALLOC]] : memref<2x3x4xi1> } + +// ----- + + +// Please check the loop bounds for each input: should be same for dynamic +func.func @test_concat_4(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor, tensor, tensor) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () +// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)> +// CHECK-DAG: #map1 = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: #map2 = affine_map<(d0) -> (d0 + 4)> +// CHECK-LABEL: func.func @test_concat_4 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref { +// CHECK: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref +// CHECK-DAG: [[VAR_c1_0_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_2_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_2_]] : memref +// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c3_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[VAR_c32_4_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[VAR_c0_5_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of 
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index fb92e385b8..1a8377990a 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -3180,3 +3180,201 @@ func.func @test_isnan(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi1> {
 // CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1, [[IV]]#2] : memref<2x3x4xi1>
 // CHECK: return [[ALLOC]] : memref<2x3x4xi1>
 }
+
+// -----
+
+
+// Check the loop bounds for each input: they should be the same even for dynamic dims
+func.func @test_concat_4(%arg0 : tensor<?x1x?xf32>, %arg1 : tensor<?x3x32xf32>, %arg2 : tensor<?x5x?xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor<?x1x?xf32>, tensor<?x3x32xf32>, tensor<?x5x?xf32>) -> tensor<*xf32>
+  "func.return"(%1) : (tensor<*xf32>) -> ()
+// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #map1 = affine_map<(d0) -> (d0 + 1)>
+// CHECK-DAG: #map2 = affine_map<(d0) -> (d0 + 4)>
+// CHECK-LABEL: func.func @test_concat_4
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x1x?xf32>, [[PARAM_1_:%.+]]: memref<?x3x32xf32>, [[PARAM_2_:%.+]]: memref<?x5x?xf32>) -> memref<?x9x32xf32> {
+// CHECK: [[VAR_c0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x1x?xf32>
+// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref<?x1x?xf32>
+// CHECK-DAG: [[VAR_c1_0_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[VAR_c1_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[VAR_c0_2_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_2_]] : memref<?x3x32xf32>
+// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c3_3_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[VAR_c32_4_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[VAR_c0_5_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_3_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_5_]] : memref<?x5x?xf32>
+// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index
+// CHECK-DAG: [[VAR_c5_6_:%.+]] = arith.constant 5 : index
+// CHECK-DAG: [[VAR_c9_:%.+]] = arith.constant 9 : index
+// CHECK-DAG: [[VAR_c2_7_:%.+]] = arith.constant 2 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_7_]] : memref<?x5x?xf32>
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<?x9x32xf32>
+// CHECK-DAG: [[VAR_c0_8_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_10_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_10_]] : memref<?x1x?xf32>
+// CHECK-DAG: [[VAR_c1_11_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[VAR_c2_12_:%.+]] = arith.constant 2 : index
+// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_12_]] : memref<?x1x?xf32>
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map0([[VAR_0_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){
+// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x1x?xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x9x32xf32>
+// CHECK: }
+// CHECK-DAG: [[VAR_c1_13_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[VAR_c1_14_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_15_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_16_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_16_]] : memref<?x3x32xf32>
+// CHECK-DAG: [[VAR_c3_17_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c32_18_:%.+]] = arith.constant 32 : index
+// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map0([[VAR_0_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){
+// CHECK-DAG: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK-DAG: [[VAR_c1_27_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_1_]]#1)
+// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<?x3x32xf32>
+// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_14_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_1_]]#2] : memref<?x9x32xf32>
+// CHECK: }
+// CHECK-DAG: [[VAR_c3_19_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c4_20_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_21_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_22_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_22_]] : memref<?x5x?xf32>
+// CHECK-DAG: [[VAR_c5_23_:%.+]] = arith.constant 5 : index
+// CHECK-DAG: [[VAR_c2_24_:%.+]] = arith.constant 2 : index
+// CHECK: [[VAR_13_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_24_]] : memref<?x5x?xf32>
+// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map0([[VAR_0_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to 5, [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){
+// CHECK-DAG: [[VAR_14_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK-DAG: [[VAR_c4_27_:%.+]] = arith.constant 4 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map2([[VAR_14_2_]]#1)
+// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_14_2_]]#0, [[VAR_14_2_]]#1, [[VAR_14_2_]]#2] : memref<?x5x?xf32>
+// CHECK: krnl.store [[LOAD_PARAM_1_MEM_1_]], [[RES_]]{{.}}[[VAR_14_2_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_2_]]#2] : memref<?x9x32xf32>
+// CHECK: }
+// CHECK-DAG: [[VAR_c5_25_:%.+]] = arith.constant 5 : index
+// CHECK-DAG: [[VAR_c9_26_:%.+]] = arith.constant 9 : index
+// CHECK: return [[RES_]] : memref<?x9x32xf32>
+// CHECK: }
+}
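Note how all three krnl.iterate ops above share the bounds #map0([[VAR_0_]]) and 32, even though dimension 2 is dynamic in the first and third inputs; only the axis bound (1, 3, 5) is per-input. A small C++ sketch of that bound selection (hypothetical helper, -1 standing in for a dynamic size; in the real lowering a dynamic common bound stays an SSA dim value):

```cpp
#include <cstdint>
#include <vector>

// Sketch of the upper-bound unification checked above. The loops for every
// input reuse the output's bounds, so a dimension that is dynamic in one
// input but constant in another (here 32) gets the constant bound in all
// loops; only the concat-axis bound comes from the individual input,
// mirroring `commonUB[axis] = ubs[axis]` in the lowering.
std::vector<int64_t> loopBoundsForInput(
    const std::vector<int64_t> &commonUB, // bounds from the inferred output
    const std::vector<int64_t> &inputDims, int64_t axis) {
  std::vector<int64_t> ubs = commonUB; // shared, possibly constant, bounds
  ubs[axis] = inputDims[axis];         // only the axis bound is per-input
  return ubs;
}
```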
+
+// -----
+
+
+// Focus on the accumulated offset used by the store op in each loop
+func.func @test_concat_5(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x3x32xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor<?x?x?xf32>, tensor<?x3x32xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
+  "func.return"(%1) : (tensor<*xf32>) -> ()
+// CHECK-DAG: #map0 = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #map1 = affine_map<(d0, d1) -> (d0 + d1 + 3)>
+// CHECK-DAG: #map2 = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK-DAG: #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
+// CHECK-DAG: #map6 = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #map7 = affine_map<(d0, d1, d2, d3, d4) -> (d4 + 3)>
+// CHECK-DAG: #map8 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+// CHECK-DAG: #map9 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4 + d6 + 3)>
+// CHECK-LABEL: func.func @test_concat_5
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x?x?xf32>, [[PARAM_1_:%.+]]: memref<?x3x32xf32>, [[PARAM_2_:%.+]]: memref<?x?x?xf32>) -> memref<?x?x32xf32> {
+// CHECK: [[VAR_c0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c1_0_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_0_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c0_1_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_1_]] : memref<?x3x32xf32>
+// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c3_2_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_5_:%.+]] = affine.apply #map0([[VAR_3_]])
+// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[VAR_c32_3_:%.+]] = arith.constant 32 : index
+// CHECK-DAG: [[VAR_c0_4_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_6_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_4_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c1_5_:%.+]] = arith.constant 1 : index
+// CHECK: [[VAR_7_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_5_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply #map1([[VAR_3_]], [[VAR_7_]])
+// CHECK-DAG: [[VAR_c2_6_:%.+]] = arith.constant 2 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_9_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_6_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]], [[VAR_8_]]) {{.*}}: memref<?x?x32xf32>
+// CHECK-DAG: [[VAR_c0_7_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_8_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_9_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c1_10_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_13_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_10_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c2_11_:%.+]] = arith.constant 2 : index
+// CHECK: [[VAR_14_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_11_]] : memref<?x?x?xf32>
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map2([[VAR_3_]], [[VAR_7_]], [[VAR_0_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to #map3([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){
+// CHECK: [[VAR_26_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_26_]]#0, [[VAR_26_]]#1, [[VAR_26_]]#2] : memref<?x?x?xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_26_]]#0, [[VAR_26_]]#1, [[VAR_26_]]#2] : memref<?x?x32xf32>
+// CHECK: }
+// CHECK: [[VAR_c1_12_:%.+]] = arith.constant 1 : index
+// CHECK: [[VAR_15_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_12_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_16_:%.+]] = affine.apply #map4([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]])
+// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_13_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_14_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_18_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_14_]] : memref<?x3x32xf32>
+// CHECK-DAG: [[VAR_c3_15_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_c32_16_:%.+]] = arith.constant 32 : index
+// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map5([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){
+// CHECK: [[VAR_26_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map6([[VAR_26_1_]]#1){{.}}[[VAR_16_]]{{.}}
+// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_26_1_]]#0, [[VAR_26_1_]]#1, [[VAR_26_1_]]#2] : memref<?x3x32xf32>
+// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_26_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_26_1_]]#2] : memref<?x?x32xf32>
+// CHECK: }
+// CHECK-DAG: [[VAR_c3_17_:%.+]] = arith.constant 3 : index
+// CHECK-DAG: [[VAR_19_:%.+]] = affine.apply #map7([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]])
+// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_c0_18_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_c0_19_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_21_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_19_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c1_20_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_22_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_20_]] : memref<?x?x?xf32>
+// CHECK-DAG: [[VAR_c2_21_:%.+]] = arith.constant 2 : index
+// CHECK: [[VAR_23_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_21_]] : memref<?x?x?xf32>
+// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map5([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to #map8([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]], [[VAR_22_]]), [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){
+// CHECK: [[VAR_26_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map6([[VAR_26_2_]]#1){{.}}[[VAR_19_]]{{.}}
+// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_26_2_]]#0, [[VAR_26_2_]]#1, [[VAR_26_2_]]#2] : memref<?x?x?xf32>
+// CHECK: krnl.store [[LOAD_PARAM_1_MEM_1_]], [[RES_]]{{.}}[[VAR_26_2_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_26_2_]]#2] : memref<?x?x32xf32>
+// CHECK: }
+// CHECK: [[VAR_c1_22_:%.+]] = arith.constant 1 : index
+// CHECK: [[VAR_24_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_22_]] : memref<?x?x?xf32>
+// CHECK: [[VAR_25_:%.+]] = affine.apply #map9([[VAR_3_]], [[VAR_7_]], [[VAR_0_]], [[VAR_13_]], [[VAR_15_]], [[VAR_22_]], [[VAR_24_]])
+// CHECK: return [[RES_]] : memref<?x?x32xf32>
+// CHECK: }
+}
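In @test_concat_5 the running offset is no longer a compile-time constant, so it reaches each store's affine map as a symbol: #map6 = affine_map<(d0)[s0] -> (d0 + s0)>, with [[VAR_16_]] and [[VAR_19_]] bound to s0. That is exactly what SymbolIndexExpr(accumulatedOffsetValue) produces in the lowering. A tiny worked check of the offset arithmetic the maps encode, using made-up runtime values `a` and `c` for the two dynamic axis sizes:

```cpp
#include <cassert>
#include <cstdint>

// Worked check of the offset arithmetic encoded by the affine maps in
// @test_concat_5. Inputs 0 and 2 have dynamic axis sizes (a and c here);
// input 1 contributes the static size 3.
int main() {
  int64_t a = 4, c = 7;         // example dynamic sizes known only at runtime
  int64_t offset1 = a;          // start of input 1 along the axis: s0 in #map6
  int64_t offset2 = a + 3;      // start of input 2: #map0/#map7, d + 3
  int64_t axisSize = a + 3 + c; // output axis size: #map1, d0 + d1 + 3
  assert(offset2 == offset1 + 3);
  assert(axisSize == offset2 + c);
  return 0;
}
```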
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 8a18b3175c..1795a3bff2 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -823,6 +823,19 @@ func.func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>,
 // CHECK: return [[RES]] : tensor<5x9x32xf32>
 }
 
+// -----
+
+func.func @test_concat_4(%arg0 : tensor<?x1x?xf32>, %arg1 : tensor<?x3x32xf32>, %arg2 : tensor<?x5x?xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor<?x1x?xf32>, tensor<?x3x32xf32>, tensor<?x5x?xf32>) -> tensor<*xf32>
+  "func.return"(%1) : (tensor<*xf32>) -> ()
+// CHECK-LABEL: func.func @test_concat_4
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x1x?xf32>, [[PARAM_1_:%.+]]: tensor<?x3x32xf32>, [[PARAM_2_:%.+]]: tensor<?x5x?xf32>) -> tensor<?x9x32xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor<?x1x?xf32>, tensor<?x3x32xf32>, tensor<?x5x?xf32>) -> tensor<?x9x32xf32>
+// CHECK: return [[VAR_0_]] : tensor<?x9x32xf32>
+// CHECK: }
+}
+
+
 // -----
 
 func.func @test_rnn_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>) -> tensor<*xf32> {
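The shape-inference test above exercises the negative-axis normalization performed before the dimension merge: axis -2 on rank-3 inputs becomes 1, and the literal axis sizes 1 + 3 + 5 sum to the expected 9. A one-file check of that arithmetic, with the values taken from the test:

```cpp
#include <cassert>
#include <cstdint>

// Check of @test_concat_4's expectations: axis -2 on rank-3 inputs
// normalizes to 1 (mirroring `axisIndex += commonRank` in computeShape), and
// the concatenated sizes yield the 9 in tensor<?x9x32xf32>.
int main() {
  int64_t axis = -2;
  const int64_t rank = 3;
  if (axis < 0)
    axis += rank; // same normalization as computeShape
  assert(axis == 1);
  int64_t cumulativeAxisSize = 1 + 3 + 5; // axis sizes of the three inputs
  assert(cumulativeAxisSize == 9);
  return 0;
}
```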