Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into fuse-broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
chentong319 committed Oct 25, 2023
2 parents d2202d1 + 968bf5f commit a7bd850
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 12 deletions.
16 changes: 14 additions & 2 deletions src/Conversion/ONNXToKrnl/Tensor/Constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,20 @@ struct ONNXConstantOpLowering : public OpConversionPattern<ONNXConstantOp> {

// Emit the constant global in Krnl dialect.
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
Value constantGlobal = create.krnl.constant(
memRefType, "constant_", constantOp.getValue().value());
mlir::Attribute constValAttr = constantOp.getValue().value();
if (memRefType.getElementType().isa<krnl::StringType>()) {
// If the onnx.ConstantOp has string type value attribute,
// The element type of the value attribute of krnl.global op should be
// "!krnl.string" instead of "!onnx.String".
ShapedType constStrType = RankedTensorType::get(
memRefType.getShape(), krnl::StringType::get(rewriter.getContext()));
SmallVector<StringRef> constStrVector(
constValAttr.dyn_cast<DenseElementsAttr>().getValues<StringAttr>());
ArrayRef<StringRef> constStrValues(constStrVector);
constValAttr = mlir::DenseElementsAttr::get(constStrType, constStrValues);
}
Value constantGlobal =
create.krnl.constant(memRefType, "constant_", constValAttr);

// Replace this operation with the generated krnl.global.
rewriter.replaceOp(op, constantGlobal);
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXDimAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ static std::optional<DimAnalysis::DimT> insertDimWhenUseful(const Value tensor,
// The correct axis is from ONNXDimOp.
axis = dimOp.getAxis();
okToInsert = true;
} else if (isa<ONNXCastOp>(op) || tensorType.isDynamicDim(axis))
} else if (isa<ONNXCastOp>(op) || (axis < (uint64_t)tensorType.getRank() &&
tensorType.isDynamicDim(axis)))
okToInsert = true;
}

Expand Down
12 changes: 5 additions & 7 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,20 +435,18 @@ template bool definedBy<ONNXConstantOp>(Value v);
template bool definedBy<ONNXDimOp>(Value v);
template bool definedBy<ONNXExpandOp>(Value v);

/// Check if a value is to store dimensions, meaning it is defined by
/// Dim/Constant/Cast/Concat.
/// Check if a value is to store dimensions, meaning it is a tensor of one
/// element or concatenation of one-element tensors.
bool areDims(Value val) {
// Value must be a 1D tensor.
Type vType = val.getType();
if (!(isRankedShapedType(vType) && (getRank(vType) == 1)))
return false;

// Base case.
if (definedBy<ONNXConstantOp>(val) || definedBy<ONNXDimOp>(val) ||
definedBy<ONNXCastOp>(val)) {
// Value must be a 1D tensor of one element.
return (getShape(vType)[0] == 1);
}
// A dimension must be a 1D tensor of one i64 element.
if ((getShape(vType)[0] == 1) && getElementType(vType).isSignlessInteger(64))
return true;

// Recursion case.
if (definedBy<ONNXConcatOp>(val)) {
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ bool hasIntegerPowerExponent(mlir::ONNXPowOp *op, int64_t &exponentValue);
template <typename OP>
bool definedBy(mlir::Value v);

/// Check if a value is to store dimensions, meaning it is defined by
/// Dim/Constant/Cast/Concat.
/// Check if a value is to store dimensions, meaning it is a tensor of one
/// element or concatenation of one-element tensors.
bool areDims(mlir::Value val);

/// Check if a value is defined by Concat to store dimensions.
Expand Down
36 changes: 36 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,39 @@ func.func private @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<
// CHECK: return [[GLOBAL]] : memref<3x2xf32>
}

// -----

// Scalar string constant: lowering must rewrite the value attribute's element
// type from !onnx.String to !krnl.string before emitting krnl.global.
func.func @test_constant_string() -> tensor<!onnx.String> {
%0 = onnx.Constant dense<"1"> : tensor<!onnx.String>
"func.return"(%0) : (tensor<!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string
// CHECK-SAME: () -> memref<!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<"1"> : tensor<!krnl.string>} : () -> memref<!krnl.string>
// CHECK: return [[VAR_0_]] : memref<!krnl.string>
}

// -----

// 1-D string constant with distinct elements: the dense<["1","2","3"]> value
// must be re-typed to tensor<3x!krnl.string> in the generated krnl.global.
func.func @test_constant_string_3elem() -> tensor<3x!onnx.String> {
%0 = onnx.Constant dense<["1", "2", "3"]> : tensor<3x!onnx.String>
"func.return"(%0) : (tensor<3x!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string_3elem
// CHECK-SAME: () -> memref<3x!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<["1", "2", "3"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
// CHECK: return [[VAR_0_]] : memref<3x!krnl.string>
}

// -----

// Splat string constant (single value dense<"1"> broadcast over 3 elements):
// the splat form must be preserved in the re-typed krnl.global value.
func.func @test_constant_string_3elem2() -> tensor<3x!onnx.String> {
%0 = onnx.Constant dense<"1"> : tensor<3x!onnx.String>
"func.return"(%0) : (tensor<3x!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string_3elem2
// CHECK-SAME: () -> memref<3x!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<"1"> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
// CHECK: return [[VAR_0_]] : memref<3x!krnl.string>
// CHECK: }
}
17 changes: 17 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,23 @@ func.func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2

// -----

// Shape inference for Expand fed by a Concat of 1-element i64 tensors: the
// constant dense<1> middle operand pins the second output dim to 1 while the
// runtime operands stay dynamic (?x1x?). Exercises the relaxed areDims()
// check (any 1-element i64 tensor, not just Dim/Constant/Cast producers).
func.func @test_expand_with_concat(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>, %arg2: tensor<f32>) -> tensor<?x1x?xf32> {
%0 = onnx.Constant dense<1> : tensor<1xi64>
%1 = "onnx.Concat"(%arg0, %0, %arg1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
%2 = "onnx.Expand"(%arg2, %1) : (tensor<f32>, tensor<3xi64>) -> tensor<?x1x?xf32>
return %2 : tensor<?x1x?xf32>

// CHECK-LABEL: func.func @test_expand_with_concat
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<f32>) -> tensor<?x1x?xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64>
// CHECK: [[VAR_1_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Expand"([[PARAM_2_]], [[VAR_1_]]) : (tensor<f32>, tensor<3xi64>) -> tensor<?x1x?xf32>
// CHECK: return [[VAR_2_]] : tensor<?x1x?xf32>
// CHECK: }
}

// -----

//===----------------------------------------------------------------------===//
/// Test shape inference for ReduceMean.
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit a7bd850

Please sign in to comment.