diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp index c285b30424..37034e021e 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp @@ -40,6 +40,7 @@ struct ONNXConcatOpLowering : public ConversionPattern { assert(succeeded(shapecomputed) && "Could not compute output shape"); auto axis = concatOp.axis(); + assert(axis >= 0 && "negative axis is supposed to have been normalized"); unsigned int inputNum = operands.size(); // Convert the output type to MemRefType. @@ -57,7 +58,12 @@ struct ONNXConcatOpLowering : public ConversionPattern { MultiDialectBuilder create(rewriter, loc); // Creates loops, one for each input. + // Since the each input should have same size for each dimension(except + // axis), we will try to make the loop upper bound the same for futher + // optimization. Difference may come from constant vs. dynamic, or dynamic + // dim of different inputs. KrnlBuilder createKrnl(rewriter, loc); + SmallVector commonUB(shapeHelper.dimsForOutput()); for (unsigned int i = 0; i < inputNum; ++i) { OpBuilder::InsertionGuard insertGuard(rewriter); // Create loop. @@ -66,7 +72,9 @@ struct ONNXConcatOpLowering : public ConversionPattern { MemRefBoundsIndexCapture bounds(operands[i]); SmallVector ubs; bounds.getDimList(ubs); - createKrnl.iterateIE(loopDef, loopDef, lbs, ubs, + // For each input, only the dimension 'axis' is different + commonUB[axis] = ubs[axis]; + createKrnl.iterateIE(loopDef, loopDef, lbs, commonUB, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { // Indices for the read and write. SmallVector readIndices, writeIndices; diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index fb92e385b8..4c15c5e448 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -3180,3 +3180,93 @@ func.func @test_isnan(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi1> { // CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1, [[IV]]#2] : memref<2x3x4xi1> // CHECK: return [[ALLOC]] : memref<2x3x4xi1> } + +// ----- + + +func.func @test_concat_4(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor, tensor, tensor) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () +// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)> +// CHECK-DAG: #map1 = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: #map2 = affine_map<(d0) -> (d0 + 4)> +// CHECK-LABEL: func.func @test_concat_4 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c0_0_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_0_]] : memref +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c1_2_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref +// CHECK-DAG: [[VAR_c0_3_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_3_]] : memref +// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c3_4_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[VAR_c32_5_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[VAR_c0_6_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_6_]] : memref +// CHECK-DAG: [[VAR_c0_7_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_7_]] : memref +// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[VAR_c5_8_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[VAR_c9_:%.+]] = arith.constant 9 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_4_]]) {{.*}}: memref +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c0_10_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_10_]] : memref +// CHECK-DAG: [[VAR_c1_11_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c2_12_:%.+]] = arith.constant 2 : index +// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_12_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){ +// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref +// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref +// CHECK: } +// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3 +// CHECK-DAG: [[VAR_c0_13_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c0_14_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_14_]] : memref +// CHECK-DAG: [[VAR_c3_15_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c32_16_:%.+]] = arith.constant 32 : index +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){ +// CHECK-DAG: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[VAR_c1_21_:%.+]] = arith.constant 1 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_1_]]#1) +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref +// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_14_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_1_]]#2] : memref +// CHECK: } +// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3 +// CHECK-DAG: [[VAR_c0_17_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c0_18_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_18_]] : memref +// CHECK-DAG: [[VAR_c5_19_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[VAR_c2_20_:%.+]] = arith.constant 2 : index +// CHECK: [[VAR_13_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_20_]] : memref +// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to 5, [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){ +// CHECK-DAG: [[VAR_14_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[VAR_c1_21_1_:%.+]] = arith.constant 1 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_2_]]#1) +// CHECK-DAG: [[VAR_c3_22_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = affine.apply #map2([[VAR_14_2_]]#1) +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_14_2_]]#0, [[VAR_14_2_]]#1, [[VAR_14_2_]]#2] : memref +// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_14_2_]]#0, [[LOAD_PARAM_1_MEM_1_]], [[VAR_14_2_]]#2] : memref +// CHECK: } +// CHECK: return [[RES_]] : memref +// CHECK: } +}