Skip to content

Commit

Permalink
lowering
Browse files Browse the repository at this point in the history
Signed-off-by: chentong319 <[email protected]>
  • Loading branch information
chentong319 committed Oct 11, 2022
1 parent 1be7c5a commit 6c38d65
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct ONNXConcatOpLowering : public ConversionPattern {
assert(succeeded(shapecomputed) && "Could not compute output shape");

auto axis = concatOp.axis();
assert(axis >= 0 && "negative axis is supposed to have been normalized");
unsigned int inputNum = operands.size();

// Convert the output type to MemRefType.
Expand All @@ -57,7 +58,12 @@ struct ONNXConcatOpLowering : public ConversionPattern {
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);

// Creates loops, one for each input.
// Since the each input should have same size for each dimension(except
// axis), we will try to make the loop upper bound the same for futher
// optimization. Difference may come from constant vs. dynamic, or dynamic
// dim of different inputs.
KrnlBuilder createKrnl(rewriter, loc);
SmallVector<IndexExpr, 4> commonUB(shapeHelper.dimsForOutput());
for (unsigned int i = 0; i < inputNum; ++i) {
OpBuilder::InsertionGuard insertGuard(rewriter);
// Create loop.
Expand All @@ -66,7 +72,9 @@ struct ONNXConcatOpLowering : public ConversionPattern {
MemRefBoundsIndexCapture bounds(operands[i]);
SmallVector<IndexExpr, 4> ubs;
bounds.getDimList(ubs);
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
// For each input, only the dimension 'axis' is different
commonUB[axis] = ubs[axis];
createKrnl.iterateIE(loopDef, loopDef, lbs, commonUB,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Indices for the read and write.
SmallVector<Value, 4> readIndices, writeIndices;
Expand Down
90 changes: 90 additions & 0 deletions test/mlir/onnx/onnx_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3180,3 +3180,93 @@ func.func @test_isnan(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi1> {
// CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1, [[IV]]#2] : memref<2x3x4xi1>
// CHECK: return [[ALLOC]] : memref<2x3x4xi1>
}

// -----


func.func @test_concat_4(%arg0 : tensor<?x1x?xf32>, %arg1 : tensor<?x3x32xf32>, %arg2 : tensor<?x5x?xf32>) -> tensor<*xf32> {
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor<?x1x?xf32>, tensor<?x3x32xf32>, tensor<?x5x?xf32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)>
// CHECK-DAG: #map1 = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #map2 = affine_map<(d0) -> (d0 + 4)>
// CHECK-LABEL: func.func @test_concat_4
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x1x?xf32>, [[PARAM_1_:%.+]]: memref<?x3x32xf32>, [[PARAM_2_:%.+]]: memref<?x5x?xf32>) -> memref<?x9x32xf32> {
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_0_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c1_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c1_2_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_c0_3_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_2_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_3_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c3_4_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[VAR_c32_5_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[VAR_c0_6_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_6_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_c0_7_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_7_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c5_8_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c9_:%.+]] = arith.constant 9 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_4_]]) {{.*}}: memref<?x9x32xf32>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_10_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_10_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_c1_11_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c2_12_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_12_]] : memref<?x1x?xf32>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){
// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x1x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_13_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_14_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_14_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_15_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c32_16_:%.+]] = arith.constant 32 : index
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){
// CHECK-DAG: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[VAR_c1_21_:%.+]] = arith.constant 1 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_1_]]#1)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<?x3x32xf32>
// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_14_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_1_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_17_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_18_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_18_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_c5_19_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c2_20_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_13_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_20_]] : memref<?x5x?xf32>
// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to 5, [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){
// CHECK-DAG: [[VAR_14_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[VAR_c1_21_1_:%.+]] = arith.constant 1 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_2_]]#1)
// CHECK-DAG: [[VAR_c3_22_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = affine.apply #map2([[VAR_14_2_]]#1)
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_14_2_]]#0, [[VAR_14_2_]]#1, [[VAR_14_2_]]#2] : memref<?x5x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_14_2_]]#0, [[LOAD_PARAM_1_MEM_1_]], [[VAR_14_2_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<?x9x32xf32>
// CHECK: }
}

0 comments on commit 6c38d65

Please sign in to comment.