Skip to content

Commit

Permalink
lowering
Browse files Browse the repository at this point in the history
Signed-off-by: chentong319 <[email protected]>
  • Loading branch information
chentong319 committed Oct 11, 2022
1 parent 6c38d65 commit a74cb69
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 27 deletions.
14 changes: 10 additions & 4 deletions src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ struct ONNXConcatOpLowering : public ConversionPattern {
// dim of different inputs.
KrnlBuilder createKrnl(rewriter, loc);
SmallVector<IndexExpr, 4> commonUB(shapeHelper.dimsForOutput());
// IndexExprScope IEScope(&rewriter, loc);
IndexExpr accumulatedOffset = LiteralIndexExpr(0);
for (unsigned int i = 0; i < inputNum; ++i) {
// Since accumulatedOffsetValue will be used in a nested IndexExprScope,
// we get the Value of this IndexExpr and pass it as a symbol.
Value accumulatedOffsetValue = accumulatedOffset.getValue();
OpBuilder::InsertionGuard insertGuard(rewriter);
// Create loop.
ValueRange loopDef = createKrnl.defineLoops(rank);
Expand All @@ -84,17 +89,18 @@ struct ONNXConcatOpLowering : public ConversionPattern {
else {
IndexExprScope IEScope(&rewriter, loc);
IndexExpr writeOffset = DimIndexExpr(loopInd[r]);
for (unsigned int j = 0; j < i; j++) {
MemRefBoundsIndexCapture operandJBounds(operands[j]);
writeOffset = writeOffset + operandJBounds.getDim(r);
}
IndexExpr accumulatedOffsetIE =
SymbolIndexExpr(accumulatedOffsetValue);
writeOffset = writeOffset + accumulatedOffsetIE;
writeIndices.emplace_back(writeOffset.getValue());
}
}
// Insert copy.
Value loadData = createKrnl.load(operands[i], loopInd);
createKrnl.store(loadData, alloc, writeIndices);
});
MemRefBoundsIndexCapture operandJBounds(operands[i]);
accumulatedOffset = accumulatedOffset + operandJBounds.getDim(axis);
}
rewriter.replaceOp(op, alloc);
return success();
Expand Down
155 changes: 132 additions & 23 deletions test/mlir/onnx/onnx_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3221,52 +3221,161 @@ func.func @test_concat_4(%arg0 : tensor<?x1x?xf32>, %arg1 : tensor<?x3x32xf32>,
// CHECK-DAG: [[VAR_c9_:%.+]] = arith.constant 9 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_4_]]) {{.*}}: memref<?x9x32xf32>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_10_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_11_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_10_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_c1_11_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c2_12_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_12_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_11_]] : memref<?x1x?xf32>
// CHECK-DAG: [[VAR_c1_12_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c2_13_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_13_]] : memref<?x1x?xf32>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 1, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){
// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x1x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK-DAG: [[VAR_c1_14_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c1_15_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_13_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_14_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_16_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_17_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_14_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_15_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c32_16_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_17_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_18_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c32_19_:%.+]] = arith.constant 32 : index
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){
// CHECK-DAG: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[VAR_c1_21_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c1_28_:%.+]] = arith.constant 1 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_1_]]#1)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<?x3x32xf32>
// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_14_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_1_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK-DAG: [[VAR_c3_20_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c4_21_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_17_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_18_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_22_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_23_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_18_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_c5_19_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c2_20_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_13_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_20_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_23_]] : memref<?x5x?xf32>
// CHECK-DAG: [[VAR_c5_24_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c2_25_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_13_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_25_]] : memref<?x5x?xf32>
// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map0([[VAR_4_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to 5, [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){
// CHECK-DAG: [[VAR_14_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[VAR_c1_21_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_c4_28_:%.+]] = arith.constant 4 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map1([[VAR_14_2_]]#1)
// CHECK-DAG: [[VAR_c3_22_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = affine.apply #map2([[VAR_14_2_]]#1)
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_14_2_]]#0, [[VAR_14_2_]]#1, [[VAR_14_2_]]#2] : memref<?x5x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_2_MEM_]], [[RES_]]{{.}}[[VAR_14_2_]]#0, [[LOAD_PARAM_1_MEM_1_]], [[VAR_14_2_]]#2] : memref<?x9x32xf32>
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map2([[VAR_14_2_]]#1)
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_14_2_]]#0, [[VAR_14_2_]]#1, [[VAR_14_2_]]#2] : memref<?x5x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_1_MEM_1_]], [[RES_]]{{.}}[[VAR_14_2_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_14_2_]]#2] : memref<?x9x32xf32>
// CHECK: }
// CHECK-DAG: [[VAR_c5_26_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_c9_27_:%.+]] = arith.constant 9 : index
// CHECK: return [[RES_]] : memref<?x9x32xf32>
// CHECK: }
}

// -----


// Lowering test for onnx.Concat with axis = -2 on three rank-3 inputs, where the
// first and third inputs have a dynamic size along the concatenation axis.
// The expected output has one krnl.iterate nest per input. The first nest stores
// with the loop indices unchanged; the second and third nests add the running
// offset to the axis-1 index through #map7, whose symbol operand is computed
// before the corresponding krnl.define_loops.
// Operand lists refer to the captured dims by their own names ([[VAR_13_]],
// [[VAR_15_]]) so the patterns do not depend on the printed SSA numbers.
func.func @test_concat_5(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x3x32xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 : si64} : (tensor<?x?x?xf32>, tensor<?x3x32xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-DAG: #map0 = affine_map<(d0) -> (d0)>
// CHECK-DAG: #map1 = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #map2 = affine_map<(d0, d1) -> (d0 + d1 + 3)>
// CHECK-DAG: #map3 = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #map4 = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-DAG: #map5 = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
// CHECK-DAG: #map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
// CHECK-DAG: #map7 = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #map8 = affine_map<(d0, d1, d2, d3, d4) -> (d4 + 3)>
// CHECK-DAG: #map9 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
// CHECK-DAG: #map10 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4 + d6 + 3)>
// CHECK-LABEL: func.func @test_concat_5
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x?x?xf32>, [[PARAM_1_:%.+]]: memref<?x3x32xf32>, [[PARAM_2_:%.+]]: memref<?x?x?xf32>) -> memref<?x?x32xf32> {
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_0_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply #map0([[VAR_1_]])
// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c0_1_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_1_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c3_2_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_5_:%.+]] = affine.apply #map1([[VAR_1_]])
// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[VAR_c32_3_:%.+]] = arith.constant 32 : index
// CHECK-DAG: [[VAR_c0_4_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_6_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_4_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c0_5_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_5_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c1_6_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_8_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_6_]] : memref<?x?x?xf32>
// CHECK: [[VAR_9_:%.+]] = affine.apply #map2([[VAR_1_]], [[VAR_8_]])
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_7_]], [[VAR_9_]]) {{.*}}: memref<?x?x32xf32>
// CHECK-DAG: [[VAR_c0_7_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_8_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_9_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_9_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c1_10_:%.+]] = arith.constant 1 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_13_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_10_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c2_11_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_14_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c2_11_]] : memref<?x?x?xf32>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to #map3([[VAR_1_]], [[VAR_8_]], [[VAR_7_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to #map4([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 32){
// CHECK: [[VAR_26_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_26_]]#0, [[VAR_26_]]#1, [[VAR_26_]]#2] : memref<?x?x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_]]{{.}}[[VAR_26_]]#0, [[VAR_26_]]#1, [[VAR_26_]]#2] : memref<?x?x32xf32>
// CHECK: }
// CHECK: [[VAR_c1_12_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_15_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_12_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = affine.apply #map5([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]])
// CHECK-DAG: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_13_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_14_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_18_:%.+]] = memref.dim [[PARAM_1_]], [[VAR_c0_14_]] : memref<?x3x32xf32>
// CHECK-DAG: [[VAR_c3_15_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_c32_16_:%.+]] = arith.constant 32 : index
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to #map6([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]]), [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 32){
// CHECK: [[VAR_26_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map7([[VAR_26_1_]]#1){{.}}[[VAR_16_]]{{.}}
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_26_1_]]#0, [[VAR_26_1_]]#1, [[VAR_26_1_]]#2] : memref<?x3x32xf32>
// CHECK: krnl.store [[LOAD_PARAM_1_MEM_]], [[RES_]]{{.}}[[VAR_26_1_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_26_1_]]#2] : memref<?x?x32xf32>
// CHECK: }
// CHECK-DAG: [[VAR_c3_17_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[VAR_19_:%.+]] = affine.apply #map8([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]])
// CHECK-DAG: [[LOOP_2_:%.+]]:3 = krnl.define_loops 3
// CHECK-DAG: [[VAR_c0_18_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_c0_19_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_21_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c0_19_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c1_20_:%.+]] = arith.constant 1 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_22_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_20_]] : memref<?x?x?xf32>
// CHECK-DAG: [[VAR_c2_21_:%.+]] = arith.constant 2 : index
// CHECK: [[VAR_23_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c2_21_]] : memref<?x?x?xf32>
// CHECK: krnl.iterate([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) with ([[LOOP_2_]]#0 -> [[I_6_:%.+]] = 0 to #map6([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]]), [[LOOP_2_]]#1 -> [[I_7_:%.+]] = 0 to #map9([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]], [[VAR_22_]]), [[LOOP_2_]]#2 -> [[I_8_:%.+]] = 0 to 32){
// CHECK: [[VAR_26_2_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_2_]]#0, [[LOOP_2_]]#1, [[LOOP_2_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = affine.apply #map7([[VAR_26_2_]]#1){{.}}[[VAR_19_]]{{.}}
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_26_2_]]#0, [[VAR_26_2_]]#1, [[VAR_26_2_]]#2] : memref<?x?x?xf32>
// CHECK: krnl.store [[LOAD_PARAM_1_MEM_1_]], [[RES_]]{{.}}[[VAR_26_2_]]#0, [[LOAD_PARAM_0_MEM_1_]], [[VAR_26_2_]]#2] : memref<?x?x32xf32>
// CHECK: }
// CHECK: [[VAR_c1_22_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_24_:%.+]] = memref.dim [[PARAM_2_]], [[VAR_c1_22_]] : memref<?x?x?xf32>
// CHECK: [[VAR_25_:%.+]] = affine.apply #map10([[VAR_1_]], [[VAR_8_]], [[VAR_7_]], [[VAR_13_]], [[VAR_15_]], [[VAR_22_]], [[VAR_24_]])
// CHECK: return [[RES_]] : memref<?x?x32xf32>
// CHECK: }
}

0 comments on commit a74cb69

Please sign in to comment.