Skip to content

Commit

Permalink
[Mhlo] BatchNorm & Concat (#1539)
Browse files Browse the repository at this point in the history
* [Mhlo] BatchNorm & Concat

Signed-off-by: chongsong.chen <[email protected]>
  • Loading branch information
chenchongsong authored Jul 20, 2022
1 parent ffca406 commit 3a6db23
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/Conversion/ONNXToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ add_onnx_mlir_library(OMONNXToMhlo
Math/Elementwise.cpp
Math/Gemm.cpp
Math/Softmax.cpp
NN/Normalization.cpp
NN/Pooling.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
Tensor/Reshape.cpp

Expand All @@ -24,7 +26,7 @@ add_onnx_mlir_library(OMONNXToMhlo
OMSupport
)

target_include_directories(OMONNXToMhlo PUBLIC
${CMAKE_SOURCE_DIR}/third_party/mlir-hlo/include
${CMAKE_BINARY_DIR}/third_party/mlir-hlo/include
target_include_directories(OMONNXToMhlo PUBLIC
${ONNX_MLIR_SRC_ROOT}/third_party/mlir-hlo/include
${ONNX_MLIR_BIN_ROOT}/third_party/mlir-hlo/include
)
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ void populateONNXToMhloConversionPattern(
populateLoweringONNXSoftmaxOpToMhloPattern(patterns, ctx);
populateLoweringONNXGemmOpToMhloPattern(patterns, ctx);
// Neural network
populateLoweringONNXNormalizationOpToMhloPattern(patterns, ctx);
populateLoweringONNXPoolingOpToMhloPattern(patterns, ctx);
// Tensor
populateLoweringONNXConcatOpToMhloPattern(patterns, ctx);
populateLoweringONNXConstantOpToMhloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToMhloPattern(patterns, ctx);
}
Expand Down
59 changes: 59 additions & 0 deletions src/Conversion/ONNXToMhlo/NN/Normalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX Normalization Operators to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXBatchNormalizationInferenceModeOpLoweringToMhlo
: public ConversionPattern {
ONNXBatchNormalizationInferenceModeOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(
mlir::ONNXBatchNormalizationInferenceModeOp::getOperationName(), 1,
ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias
ONNXBatchNormalizationInferenceModeOpAdaptor operandAdaptor(
operands, op->getAttrDictionary());
Location loc = op->getLoc();

Value operand = operandAdaptor.X();
Value scale = operandAdaptor.scale();
Value bias = operandAdaptor.B();
Value mean = operandAdaptor.mean();
Value variance = operandAdaptor.var();
llvm::APFloat eps = operandAdaptor.epsilon();

Value result = rewriter.create<mhlo::BatchNormInferenceOp>(loc,
op->getResultTypes(), operand, scale, bias, mean, variance, eps, 1);
rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

void populateLoweringONNXNormalizationOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXBatchNormalizationInferenceModeOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ void populateLoweringONNXSoftmaxOpToMhloPattern(
void populateLoweringONNXGemmOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
// `NN` directory methods:
void populateLoweringONNXNormalizationOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXPoolingOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
// `Tensor` directory methods:
void populateLoweringONNXConcatOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXConstantOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXReshapeOpToMhloPattern(
Expand Down
65 changes: 65 additions & 0 deletions src/Conversion/ONNXToMhlo/Tensor/Concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
//
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Concat Operator to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Support/TypeUtilities.hpp"

#define DEBUG_TYPE "onnx_to_mhlo"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXConcatOpLoweringToMhlo : public ConversionPattern {
ONNXConcatOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

Location loc = op->getLoc();
ONNXConcatOpAdaptor operandAdaptor(operands);
ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op);

assert(op->getNumResults() == 1 && "ONNXConcatOp shoule have 1 result");
Type resultType = op->getResult(0).getType();
if (!onnx_mlir::isRankedShapedType(resultType)) {
LLVM_DEBUG(llvm::dbgs() << "Concat Output Is Not Ranked\n");
return failure();
}
int64_t rank = onnx_mlir::getRank(resultType);
int64_t axis = concatOp.axis();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range");

ValueRange inputs = operandAdaptor.inputs();
Value result = rewriter.create<mhlo::ConcatenateOp>(
loc, op->getResultTypes(), inputs, rewriter.getI64IntegerAttr(axis));
rewriter.replaceOp(op, result);
return success();
}
};

} // namespace

void populateLoweringONNXConcatOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConcatOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
6 changes: 6 additions & 0 deletions src/Support/TypeUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ ArrayRef<int64_t> getShape(Type ty) {
return ty.cast<ShapedType>().getShape();
}

/// Get rank.
int64_t getRank(Type ty) {
assert(isRankedShapedType(ty) && "Type must be ranked");
return ty.cast<ShapedType>().getRank();
}

/// Get the number of elements.
int64_t getNumberOfElements(Type ty) {
ArrayRef<int64_t> shape = getShape(ty);
Expand Down
2 changes: 2 additions & 0 deletions src/Support/TypeUtilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ mlir::Type getElementType(mlir::Type ty);
bool isRankedShapedType(mlir::Type ty);
/// Get shape.
llvm::ArrayRef<int64_t> getShape(mlir::Type ty);
/// Get rank.
int64_t getRank(mlir::Type ty);
/// Get the number of elements.
int64_t getNumberOfElements(mlir::Type ty);
/// Get the element size in bytes.
Expand Down
19 changes: 19 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/NN/Normalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s

func @test_batch_normalization(%arg0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
%0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%1 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%2 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%3 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32>
%4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
"func.return"(%4) : (tensor<1x3x10x10xf32>) -> ()
// CHECK-LABEL: func @test_batch_normalization
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_1_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.batch_norm_inference"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, feature_index = 1 : i64} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32>
// CHECK-NEXT: return [[VAR_4_]] : tensor<1x3x10x10xf32>
// CHECK-NEXT: }
}
25 changes: 25 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/Tensor/Concat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s

// Test when output shape is unkown
func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
%0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
"func.return"(%0) : (tensor<5x5x?x32xf32>) -> ()
// CHECK-LABEL: func @test_concat_dynamic_shape
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32>
// CHECK-NEXT: }
}

// -----

// Test when axis is negative
func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
%0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
"func.return"(%0) : (tensor<5x5x4x32xf32>) -> ()
// CHECK-LABEL: func @test_concat_negative_axis
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x4x32xf32>
// CHECK-NEXT: }
}

0 comments on commit 3a6db23

Please sign in to comment.