From 5db6e462abbe746e3a138b08643d0bd57f4a516c Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 25 Oct 2023 08:09:56 +0900 Subject: [PATCH 1/2] Relax the condition for checking dimensions defined by Concat (#2579) Signed-off-by: Tung D. Le --- src/Dialect/ONNX/ONNXDimAnalysis.cpp | 3 ++- src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 12 +++++------- src/Dialect/ONNX/ONNXOps/OpHelper.hpp | 4 ++-- test/mlir/onnx/onnx_shape_inference.mlir | 17 +++++++++++++++++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 2bb3595904..17a939ea10 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -84,7 +84,8 @@ static std::optional insertDimWhenUseful(const Value tensor, // The correct axis is from ONNXDimOp. axis = dimOp.getAxis(); okToInsert = true; - } else if (isa(op) || tensorType.isDynamicDim(axis)) + } else if (isa(op) || (axis < (uint64_t)tensorType.getRank() && + tensorType.isDynamicDim(axis))) okToInsert = true; } diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 2b2d8f5bc4..6aff875a85 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -435,8 +435,8 @@ template bool definedBy(Value v); template bool definedBy(Value v); template bool definedBy(Value v); -/// Check if a value is to store dimensions, meaning it is defined by -/// Dim/Constant/Cast/Concat. +/// Check if a value is to store dimensions, meaning it is a tensor of one +/// element or concatenation of one-element tensors. bool areDims(Value val) { // Value must be a 1D tensor. Type vType = val.getType(); @@ -444,11 +444,9 @@ bool areDims(Value val) { return false; // Base case. - if (definedBy(val) || definedBy(val) || - definedBy(val)) { - // Value must be a 1D tensor of one element. - return (getShape(vType)[0] == 1); - } + // A dimension must be a 1D tensor of one i64 element. + if ((getShape(vType)[0] == 1) && getElementType(vType).isSignlessInteger(64)) + return true; // Recursion case. if (definedBy(val)) { diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index 17cbe99ca4..3e6db619e9 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -256,8 +256,8 @@ bool hasIntegerPowerExponent(mlir::ONNXPowOp *op, int64_t &exponentValue); template bool definedBy(mlir::Value v); -/// Check if a value is to store dimensions, meaning it is defined by -/// Dim/Constant/Cast/Concat. +/// Check if a value is to store dimensions, meaning it is a tensor of one +/// element or concatenation of one-element tensors. bool areDims(mlir::Value val); /// Check if a value is defined by Concat to store dimensions. diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 34ed116591..1f17490ae0 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -2292,6 +2292,23 @@ func.func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2 // ----- +func.func @test_expand_with_concat(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>, %arg2: tensor) -> tensor { + %0 = onnx.Constant dense<1> : tensor<1xi64> + %1 = "onnx.Concat"(%arg0, %0, %arg1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + %2 = "onnx.Expand"(%arg2, %1) : (tensor, tensor<3xi64>) -> tensor + return %2 : tensor + +// CHECK-LABEL: func.func @test_expand_with_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Expand"([[PARAM_2_]], [[VAR_1_]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: return [[VAR_2_]] : tensor +// CHECK: } +} + +// ----- + //===----------------------------------------------------------------------===// /// Test shape inference for ReduceMean. //===----------------------------------------------------------------------===// From 968bf5f831efeb2f14a3e4aa62c91fd8bdfa07fd Mon Sep 17 00:00:00 2001 From: Yasushi Negishi Date: Wed, 25 Oct 2023 17:27:38 +0900 Subject: [PATCH 2/2] Fix onnx-to-krnl lowering of onnx.ConstantOp with string type values. (#2574) Signed-off-by: Yasushi Negishi --- src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 16 +++++++-- .../onnx_to_krnl/Tensor/Constant.mlir | 36 +++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 19ef38abbd..dfdc35d910 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -39,8 +39,20 @@ struct ONNXConstantOpLowering : public OpConversionPattern { // Emit the constant global in Krnl dialect. MultiDialectBuilder create(rewriter, loc); - Value constantGlobal = create.krnl.constant( - memRefType, "constant_", constantOp.getValue().value()); + mlir::Attribute constValAttr = constantOp.getValue().value(); + if (memRefType.getElementType().isa()) { + // If the onnx.ConstantOp has string type value attribute, + // The element type of the value attribute of krnl.global op should be + // "!krnl.string" instead of "!onnx.String". + ShapedType constStrType = RankedTensorType::get( + memRefType.getShape(), krnl::StringType::get(rewriter.getContext())); + SmallVector constStrVector( + constValAttr.dyn_cast().getValues()); + ArrayRef constStrValues(constStrVector); + constValAttr = mlir::DenseElementsAttr::get(constStrType, constStrValues); + } + Value constantGlobal = + create.krnl.constant(memRefType, "constant_", constValAttr); // Replace this operation with the generated krnl.global. rewriter.replaceOp(op, constantGlobal); diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir index 7585ebe59a..772143937f 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir @@ -8,3 +8,39 @@ func.func private @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor< // CHECK: return [[GLOBAL]] : memref<3x2xf32> } +// ----- + +func.func @test_constant_string() -> tensor { + %0 = onnx.Constant dense<"1"> : tensor + "func.return"(%0) : (tensor) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string + // CHECK-SAME: () -> memref { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<"1"> : tensor} : () -> memref + // CHECK: return [[VAR_0_]] : memref +} + +// ----- + +func.func @test_constant_string_3elem() -> tensor<3x!onnx.String> { + %0 = onnx.Constant dense<["1", "2", "3"]> : tensor<3x!onnx.String> + "func.return"(%0) : (tensor<3x!onnx.String>) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string_3elem + // CHECK-SAME: () -> memref<3x!krnl.string> { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<["1", "2", "3"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string> + // CHECK: return [[VAR_0_]] : memref<3x!krnl.string> +} + +// ----- + +func.func @test_constant_string_3elem2() -> tensor<3x!onnx.String> { + %0 = onnx.Constant dense<"1"> : tensor<3x!onnx.String> + "func.return"(%0) : (tensor<3x!onnx.String>) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string_3elem2 + // CHECK-SAME: () -> memref<3x!krnl.string> { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<"1"> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string> + // CHECK: return [[VAR_0_]] : memref<3x!krnl.string> + // CHECK: } +}