Skip to content

Commit

Permalink
Forward input dims in Reshape ShapeHelper (onnx#2828)
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
tungld and AlexandreEichenberger authored Jun 12, 2024
1 parent 7f4f510 commit 713fc2e
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 53 deletions.
8 changes: 7 additions & 1 deletion src/Dialect/Mlir/IndexExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,9 +1095,15 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const {
assert(trueVal.canBeUsedInScope() && "trueVal incompatible scope");
assert(falseVal.canBeUsedInScope() && "falseVal incompatible scope");
// When compare result is literal, just feed forward the right value.
// Do not deep copy the question mark to keep it unchanged.
if (compare.isLiteral()) {
if (compare.getLiteral())
if (compare.getLiteral()) {
if (trueVal.isQuestionmark())
return trueVal;
return trueVal.deepCopy();
}
if (falseVal.isQuestionmark())
return falseVal;
return falseVal.deepCopy();
}
// Dynamic value, just set as undefined during shape inference pass.
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() {
// should be -1 (represented as QuestionmarkIndexExpr)
for (unsigned i = 0; i < outputRank; ++i) {
if (hasShapeAndRank(data)) {
IndexExpr dimShape = createIE->getIntFromArrayAsSymbol(shape, i);
outputDims[i] = outputDims[i].selectOrSelf(
outputDims[i] == -1, numOfElements.floorDiv(numOfElementsFromShape));
dimShape == -1, numOfElements.floorDiv(numOfElementsFromShape));
} else {
// ToFix: can not check getAllowzero because the operandAdaptor is
// constructed without attributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,56 +8,67 @@ func.func private @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>)
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func private @test_reshape
// CHECK: ([[PARAM_0_:%.+]]: memref<?x10xf32>, [[PARAM_1_:%.+]]: memref<4xi64>) -> memref<?x?x?x?xf32> {
// CHECK-LABEL: func.func private @test_reshape
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x10xf32>, [[PARAM_1_:%.+]]: memref<4xi64>) -> memref<?x?x?x?xf32> {
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : index
// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_0_]]{{.}}
// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_5_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi eq, [[VAR_3_]], [[CST_0_]] : index
// CHECK: [[VAR_6_:%.+]] = arith.select [[VAR_4_]], [[VAR_5_]], [[VAR_3_]] : index
// CHECK: [[VAR_7_:%.+]] = arith.cmpi eq, [[VAR_6_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[CST_1_]], [[VAR_6_]] : index
// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.cmpi eq, [[VAR_2_]], [[CST_0_]] : index
// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK: [[VAR_4_:%.+]] = arith.select [[VAR_3_]], [[VAR_dim_0_]], [[VAR_2_]] : index
// CHECK: [[VAR_5_:%.+]] = arith.cmpi eq, [[VAR_4_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_6_:%.+]] = arith.select [[VAR_5_]], [[CST_1_]], [[VAR_4_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_1_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_1_]] : i64 to index
// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[VAR_10_]], [[CST_0_]] : index
// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_10_]], [[VAR_10_]] : index
// CHECK: [[VAR_13_:%.+]] = arith.cmpi eq, [[VAR_12_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_14_:%.+]] = arith.select [[VAR_13_]], [[CST_1_]], [[VAR_12_]] : index
// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[VAR_8_]], [[VAR_14_]] : index
// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_1_]] : i64 to index
// CHECK: [[VAR_9_:%.+]] = arith.cmpi eq, [[VAR_8_]], [[CST_0_]] : index
// CHECK: [[VAR_10_:%.+]] = arith.select [[VAR_9_]], [[CST_10_]], [[VAR_8_]] : index
// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[VAR_10_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_1_]], [[VAR_10_]] : index
// CHECK-DAG: [[VAR_13_:%.+]] = arith.muli [[VAR_6_]], [[VAR_12_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_2_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_2_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_2_]] : i64 to index
// CHECK: [[VAR_18_:%.+]] = arith.cmpi eq, [[VAR_17_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_19_:%.+]] = arith.select [[VAR_18_]], [[CST_1_]], [[VAR_17_]] : index
// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_15_]], [[VAR_19_]] : index
// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_2_]] : i64 to index
// CHECK: [[VAR_16_:%.+]] = arith.cmpi eq, [[VAR_15_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_17_:%.+]] = arith.select [[VAR_16_]], [[CST_1_]], [[VAR_15_]] : index
// CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[VAR_13_]], [[VAR_17_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_3_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_3_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_22_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_3_]] : i64 to index
// CHECK: [[VAR_23_:%.+]] = arith.cmpi eq, [[VAR_22_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_24_:%.+]] = arith.select [[VAR_23_]], [[CST_1_]], [[VAR_22_]] : index
// CHECK-DAG: [[VAR_25_:%.+]] = arith.muli [[VAR_20_]], [[VAR_24_]] : index
// CHECK-DAG: [[VAR_26_:%.+]] = arith.cmpi eq, [[VAR_6_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_27_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_25_]] : index
// CHECK-DAG: [[VAR_29_:%.+]] = arith.cmpi eq, [[VAR_12_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_30_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_25_]] : index
// CHECK-DAG: [[VAR_28_:%.+]] = arith.select [[VAR_26_]], [[VAR_27_]], [[VAR_6_]] : index
// CHECK-DAG: [[VAR_31_:%.+]] = arith.select [[VAR_29_]], [[VAR_30_]], [[VAR_12_]] : index
// CHECK-DAG: [[VAR_32_:%.+]] = arith.cmpi eq, [[VAR_17_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_33_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_25_]] : index
// CHECK-DAG: [[VAR_34_:%.+]] = arith.select [[VAR_32_]], [[VAR_33_]], [[VAR_17_]] : index
// CHECK-DAG: [[VAR_36_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_25_]] : index
// CHECK-DAG: [[VAR_35_:%.+]] = arith.cmpi eq, [[VAR_22_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_37_:%.+]] = arith.select [[VAR_35_]], [[VAR_36_]], [[VAR_22_]] : index
// CHECK: [[VAR_38_:%.+]] = arith.muli [[VAR_37_]], [[VAR_34_]] : index
// CHECK: [[VAR_39_:%.+]] = arith.muli [[VAR_38_]], [[VAR_31_]] : index
// CHECK: [[VAR_40_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: {{.}}[[VAR_28_]], [[VAR_31_]], [[VAR_34_]], [[VAR_37_]]{{.}}, strides: {{.}}[[VAR_39_]], [[VAR_38_]], [[VAR_37_]], 1] : memref<?x10xf32> to memref<?x?x?x?xf32>
// CHECK: return [[VAR_40_]] : memref<?x?x?x?xf32>
// CHECK: [[VAR_20_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_3_]] : i64 to index
// CHECK: [[VAR_21_:%.+]] = arith.cmpi eq, [[VAR_20_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[CST_1_]], [[VAR_20_]] : index
// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[VAR_18_]], [[VAR_22_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_4_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_25_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_4_]] : i64 to index
// CHECK-DAG: [[VAR_26_:%.+]] = arith.cmpi eq, [[VAR_25_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_27_:%.+]] = arith.floordivsi [[VAR_0_]], [[VAR_23_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_28_:%.+]] = arith.select [[VAR_26_]], [[VAR_27_]], [[VAR_4_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_5_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_1_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_30_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_5_]] : i64 to index
// CHECK-DAG: [[VAR_31_:%.+]] = arith.cmpi eq, [[VAR_30_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_32_:%.+]] = arith.floordivsi [[VAR_0_]], [[VAR_23_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_33_:%.+]] = arith.select [[VAR_31_]], [[VAR_32_]], [[VAR_10_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_6_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_2_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_35_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_6_]] : i64 to index
// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpi eq, [[VAR_35_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_37_:%.+]] = arith.floordivsi [[VAR_0_]], [[VAR_23_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_15_]] : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_7_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_3_]]{{.}} : memref<4xi64>
// CHECK: [[VAR_40_:%.+]] = arith.index_cast [[LOAD_PARAM_1_MEM_7_]] : i64 to index
// CHECK-DAG: [[VAR_41_:%.+]] = arith.cmpi eq, [[VAR_40_]], [[CST_minus_1_]] : index
// CHECK-DAG: [[VAR_42_:%.+]] = arith.floordivsi [[VAR_0_]], [[VAR_23_]] : index
// CHECK: [[VAR_43_:%.+]] = arith.select [[VAR_41_]], [[VAR_42_]], [[VAR_20_]] : index
// CHECK: [[VAR_44_:%.+]] = arith.muli [[VAR_43_]], [[VAR_38_]] : index
// CHECK: [[VAR_45_:%.+]] = arith.muli [[VAR_44_]], [[VAR_33_]] : index
// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: {{.}}[[VAR_28_]], [[VAR_33_]], [[VAR_38_]], [[VAR_43_]]{{.}}, strides: {{.}}[[VAR_45_]], [[VAR_44_]], [[VAR_43_]], 1] : memref<?x10xf32> to memref<?x?x?x?xf32>
// CHECK: return [[VAR_reinterpret_cast_]] : memref<?x?x?x?xf32>
// CHECK: }
}

Expand Down
26 changes: 16 additions & 10 deletions test/mlir/conversion/onnx_to_stablehlo/Tensor/Reshape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,22 @@ func.func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi
// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[CST_32_]], [[VAR_18_]] : index
// CHECK: [[VAR_21_:%.+]] = arith.cmpi eq, [[VAR_20_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[CST_1_]], [[VAR_20_]] : index
// CHECK: [[VAR_23_:%.+]] = arith.muli [[VAR_17_]], [[VAR_22_]] : index
// CHECK: [[VAR_24_:%.+]] = arith.floordivsi [[CST_800_]], [[VAR_23_]] : index
// CHECK-DAG: [[VAR_25_:%.+]] = arith.select [[VAR_4_]], [[VAR_24_]], [[VAR_3_]] : index
// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_9_]], [[VAR_24_]], [[VAR_8_]] : index
// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_15_]], [[VAR_24_]], [[VAR_14_]] : index
// CHECK-DAG: [[VAR_28_:%.+]] = arith.select [[VAR_21_]], [[VAR_24_]], [[VAR_20_]] : index
// CHECK: [[VAR_29_:%.+]] = shape.from_extents [[VAR_25_]], [[VAR_26_]], [[VAR_27_]], [[VAR_28_]] : index, index, index, index
// CHECK: [[VAR_30_:%.+]] = shape.to_extent_tensor [[VAR_29_]] : !shape.shape -> tensor<4xindex>
// CHECK: [[VAR_31_:%.+]] = stablehlo.dynamic_reshape [[PARAM_0_]], [[VAR_30_]] : (tensor<5x5x1x32xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: return [[VAR_31_]] : tensor<?x?x?x?xf32>
// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[VAR_17_]], [[VAR_22_]] : index
// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpi eq, [[VAR_1_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_25_:%.+]] = arith.floordivsi [[CST_800_]], [[VAR_23_]] : index
// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_3_]] : index
// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpi eq, [[VAR_6_]], [[CST_minus_1_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_25_]], [[VAR_8_]] : index
// CHECK-DAG: [[VAR_29_:%.+]] = arith.cmpi eq, [[VAR_12_]], [[CST_minus_1_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_30_:%.+]] = arith.select [[VAR_29_]], [[VAR_25_]], [[VAR_14_]] : index
// CHECK-DAG: [[VAR_31_:%.+]] = arith.cmpi eq, [[VAR_18_]], [[CST_minus_1_]] : index
// CHECK: [[VAR_32_:%.+]] = arith.select [[VAR_31_]], [[VAR_25_]], [[VAR_20_]] : index
// CHECK: [[VAR_33_:%.+]] = shape.from_extents [[VAR_26_]], [[VAR_28_]], [[VAR_30_]], [[VAR_32_]] : index, index, index, index
// CHECK: [[VAR_34_:%.+]] = shape.to_extent_tensor [[VAR_33_]] : !shape.shape -> tensor<4xindex>
// CHECK: [[VAR_35_:%.+]] = stablehlo.dynamic_reshape [[PARAM_0_]], [[VAR_34_]] : (tensor<5x5x1x32xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: return [[VAR_35_]] : tensor<?x?x?x?xf32>
// CHECK: }
}

Expand Down
19 changes: 19 additions & 0 deletions test/mlir/onnx/onnx_dim_analysis.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,25 @@ func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor<?x

// -----

// Dim-analysis test for onnx.Reshape with allowzero = 0: the shape constant
// [0, 0, 12, 64] has zeros in its first two entries, which (with allowzero = 0)
// mean "keep the corresponding input dimension". The expected output below
// verifies that the two dynamic input dims of %arg0 (axes 0 and 1) are grouped
// with the matching dynamic dims of the Reshape result (same group_id per axis),
// i.e. the input dims are forwarded through the Reshape shape helper.
func.func @test_reshape_allowzero(%arg0: tensor<?x?x768xf32>) -> tensor<?x?x12x64xf32> {
%184 = onnx.Constant dense<[0, 0, 12, 64]> : tensor<4xi64>
%494 = "onnx.Reshape"(%arg0, %184) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
"onnx.Return"(%494) : (tensor<?x?x12x64xf32>) -> ()

// CHECK-LABEL: func.func @test_reshape_allowzero
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>) -> tensor<?x?x12x64xf32> {
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x768xf32>) -> ()
// CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x768xf32>) -> ()
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 0, 12, 64]> : tensor<4xi64>
// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x12x64xf32>
// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor<?x?x12x64xf32>) -> ()
// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor<?x?x12x64xf32>) -> ()
// CHECK: onnx.Return [[VAR_1_]] : tensor<?x?x12x64xf32>
// CHECK: }
}

// -----

func.func @test_expand_from_concat_dims(%arg0: tensor<1x256xi64>, %arg1: tensor<?x256xi64>) -> tensor<?x256xi64> {
%0 = onnx.Constant dense<256> : tensor<1xi64>
%1 = "onnx.Dim"(%arg1) {axis = 0 : si64} : (tensor<?x256xi64>) -> tensor<1xi64>
Expand Down

0 comments on commit 713fc2e

Please sign in to comment.