Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mhlo] BatchNorm & Concat #1539

Merged
merged 7 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Conversion/ONNXToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ add_onnx_mlir_library(OMONNXToMhlo
Math/Elementwise.cpp
Math/Gemm.cpp
Math/Softmax.cpp
NN/Normalization.cpp
NN/Pooling.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
Tensor/Reshape.cpp

Expand All @@ -24,7 +26,7 @@ add_onnx_mlir_library(OMONNXToMhlo
OMSupport
)

target_include_directories(OMONNXToMhlo PUBLIC
target_include_directories(OMONNXToMhlo PUBLIC
${CMAKE_SOURCE_DIR}/third_party/mlir-hlo/include
${CMAKE_BINARY_DIR}/third_party/mlir-hlo/include
)
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ void populateONNXToMhloConversionPattern(
populateLoweringONNXSoftmaxOpToMhloPattern(patterns, ctx);
populateLoweringONNXGemmOpToMhloPattern(patterns, ctx);
// Neural network
populateLoweringONNXNormalizationOpToMhloPattern(patterns, ctx);
populateLoweringONNXPoolingOpToMhloPattern(patterns, ctx);
// Tensor
populateLoweringONNXConcatOpToMhloPattern(patterns, ctx);
populateLoweringONNXConstantOpToMhloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToMhloPattern(patterns, ctx);
}
Expand Down
59 changes: 59 additions & 0 deletions src/Conversion/ONNXToMhlo/NN/Normalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX Normalization Operators to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXBatchNormalizationInferenceModeOpLoweringToMhlo
: public ConversionPattern {
ONNXBatchNormalizationInferenceModeOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(
mlir::ONNXBatchNormalizationInferenceModeOp::getOperationName(), 1,
ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias
ONNXBatchNormalizationInferenceModeOpAdaptor operandAdaptor(
operands, op->getAttrDictionary());
Location loc = op->getLoc();

Value operand = operandAdaptor.X();
Value scale = operandAdaptor.scale();
Value bias = operandAdaptor.B();
Value mean = operandAdaptor.mean();
Value variance = operandAdaptor.var();
llvm::APFloat eps = operandAdaptor.epsilon();

Value result = rewriter.create<mhlo::BatchNormInferenceOp>(loc,
op->getResultTypes(), operand, scale, bias, mean, variance, eps, 1);
rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

void populateLoweringONNXNormalizationOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXBatchNormalizationInferenceModeOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ void populateLoweringONNXSoftmaxOpToMhloPattern(
void populateLoweringONNXGemmOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
// `NN` directory methods:
void populateLoweringONNXNormalizationOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXPoolingOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
// `Tensor` directory methods:
void populateLoweringONNXConcatOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXConstantOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXReshapeOpToMhloPattern(
Expand Down
61 changes: 61 additions & 0 deletions src/Conversion/ONNXToMhlo/Tensor/Concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Concat Operator to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXConcatOpLoweringToMhlo : public ConversionPattern {
ONNXConcatOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

Location loc = op->getLoc();
ONNXConcatOpAdaptor operandAdaptor(operands);
ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op);

assert(op->getNumResults() == 1 && "ONNXConcatOp shoule have 1 result");
Type resultType = op->getResult(0).getType();
if (!onnx_mlir::isRankedShapedType(resultType)) {
op->emitError() << "Concat Output Is Not Ranked\n";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's an error. Suggest using llvm::dbgs() and DEBUG_TYPE here.

return failure();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not specific to this PR, I am a bit concerned about failures during rewrite patterns that can cause the whole compilation pipeline terminated. This would work if all ONNX ops can be lowered to MHLO. However, if an op does not satisfy the condition here, there is no a fallback path so that the op will be lowered to other dialects such as Krnl for CPU or ZHigh for NNPA.

One way to solve this is to use DynamicallyLegalOp in --convert--onnx-to-mhlo pass to selectively add operations into the conversion target. We did this for --convert-onnx-to-zhigh: https://github.com/onnx/onnx-mlir/blob/main/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp#L323

FYI @AlexandreEichenberger

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment! Will discuss with co-workers about the fallback path for mhlo lowering. @yaochengji

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tungld Currently -convert-onnx-to-hlo uses applyPartialConversion and it doesn't make ONNX Dialect or any ONNX operations illegal. Therefore it is fine if there're some ONNX ops couldn't be converted to MHLO.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yaochengji I see. Thanks for clarifying that!

}
int64_t rank = onnx_mlir::getRank(resultType);
int64_t axis = concatOp.axis();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range");

ValueRange inputs = operandAdaptor.inputs();
Value result = rewriter.create<mhlo::ConcatenateOp>(
loc, op->getResultTypes(), inputs, rewriter.getI64IntegerAttr(axis));
rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

void populateLoweringONNXConcatOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConcatOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
6 changes: 6 additions & 0 deletions src/Support/TypeUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ ArrayRef<int64_t> getShape(Type ty) {
return ty.cast<ShapedType>().getShape();
}

/// Get rank.
int64_t getRank(Type ty) {
assert(isRankedShapedType(ty) && "Type must be ranked");
return ty.cast<ShapedType>().getRank();
}

/// Get the number of elements.
int64_t getNumberOfElements(Type ty) {
ArrayRef<int64_t> shape = getShape(ty);
Expand Down
2 changes: 2 additions & 0 deletions src/Support/TypeUtilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ mlir::Type getElementType(mlir::Type ty);
bool isRankedShapedType(mlir::Type ty);
/// Get shape.
llvm::ArrayRef<int64_t> getShape(mlir::Type ty);
/// Get rank.
int64_t getRank(mlir::Type ty);
/// Get the number of elements.
int64_t getNumberOfElements(mlir::Type ty);
/// Get the element size in bytes.
Expand Down
19 changes: 19 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/NN/Normalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s

func @test_batch_normalization(%arg0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
%0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%1 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%2 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%3 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
"func.return"(%4) : (tensor<1x3x10x10xf32>) -> ()
// CHECK-LABEL: func @test_batch_normalization
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.batch_norm_inference"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, feature_index = 1 : i64} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
// CHECK-NEXT: return [[VAR_4_]] : tensor<1x3x10x10xf32>
// CHECK-NEXT: }
}
25 changes: 25 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s

// Test when output shape is unkown
func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
%0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
"func.return"(%0) : (tensor<5x5x?x32xf32>) -> ()
// CHECK-LABEL: func @test_concat_dynamic_shape
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32>
// CHECK-NEXT: }
}

// -----

// Test when axis is negative
func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
%0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
"func.return"(%0) : (tensor<5x5x4x32xf32>) -> ()
// CHECK-LABEL: func @test_concat_negative_axis
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x4x32xf32>
// CHECK-NEXT: }
}