-
Notifications
You must be signed in to change notification settings - Fork 325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Mhlo] BatchNorm & Concat #1539
Changes from 4 commits
94d96cb
2cedfc1
e3ed189
1b74bda
62e8096
117929b
5afc553
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===----------- Normalization.cpp - Lowering Normalization Ops -----------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX Normalization Operators to Mhlo dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
struct ONNXBatchNormalizationInferenceModeOpLoweringToMhlo | ||
: public ConversionPattern { | ||
ONNXBatchNormalizationInferenceModeOpLoweringToMhlo(MLIRContext *ctx) | ||
: ConversionPattern( | ||
mlir::ONNXBatchNormalizationInferenceModeOp::getOperationName(), 1, | ||
ctx) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
// batchnorm{epsilon}(x, scale, bias, mean, variance) = | ||
// scale * (x - mean) / sqrt(variance + epsilon) + bias | ||
ONNXBatchNormalizationInferenceModeOpAdaptor operandAdaptor( | ||
operands, op->getAttrDictionary()); | ||
Location loc = op->getLoc(); | ||
|
||
Value operand = operandAdaptor.X(); | ||
Value scale = operandAdaptor.scale(); | ||
Value bias = operandAdaptor.B(); | ||
Value mean = operandAdaptor.mean(); | ||
Value variance = operandAdaptor.var(); | ||
llvm::APFloat eps = operandAdaptor.epsilon(); | ||
|
||
Value result = rewriter.create<mhlo::BatchNormInferenceOp>(loc, | ||
op->getResultTypes(), operand, scale, bias, mean, variance, eps, 1); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// Registers the ONNX-normalization-to-Mhlo lowering patterns into the
// given pattern set.
void populateLoweringONNXNormalizationOpToMhloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.add<ONNXBatchNormalizationInferenceModeOpLoweringToMhlo>(ctx);
}
|
||
} // namespace onnx_mlir |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===---------------- Concat.cpp - Lowering Concat Op -------------------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers the ONNX Concat Operator to Mhlo dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" | ||
#include "src/Support/TypeUtilities.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
struct ONNXConcatOpLoweringToMhlo : public ConversionPattern { | ||
ONNXConcatOpLoweringToMhlo(MLIRContext *ctx) | ||
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
|
||
Location loc = op->getLoc(); | ||
ONNXConcatOpAdaptor operandAdaptor(operands); | ||
ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op); | ||
|
||
assert(op->getNumResults() == 1 && "ONNXConcatOp shoule have 1 result"); | ||
Type resultType = op->getResult(0).getType(); | ||
if (!onnx_mlir::isRankedShapedType(resultType)) { | ||
op->emitError() << "Concat Output Is Not Ranked\n"; | ||
return failure(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not specific to this PR, I am a bit concerned about failures during rewrite patterns that can cause the whole compilation pipeline terminated. This would work if all ONNX ops can be lowered to MHLO. However, if an op does not satisfy the condition here, there is no a fallback path so that the op will be lowered to other dialects such as Krnl for CPU or ZHigh for NNPA. One way to solve this is to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the comment! Will discuss with co-workers about the fallback path for mhlo lowering. @yaochengji There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tungld Currently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yaochengji I see. Thanks for clarifying that! |
||
} | ||
int64_t rank = onnx_mlir::getRank(resultType); | ||
int64_t axis = concatOp.axis(); | ||
axis = axis >= 0 ? axis : rank + axis; | ||
assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range"); | ||
|
||
ValueRange inputs = operandAdaptor.inputs(); | ||
Value result = rewriter.create<mhlo::ConcatenateOp>( | ||
loc, op->getResultTypes(), inputs, rewriter.getI64IntegerAttr(axis)); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// Registers the ONNX-Concat-to-Mhlo lowering pattern into the given
// pattern set.
void populateLoweringONNXConcatOpToMhloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.add<ONNXConcatOpLoweringToMhlo>(ctx);
}
|
||
} // namespace onnx_mlir |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s | ||
|
||
func @test_batch_normalization(%arg0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> { | ||
%0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%1 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%2 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%3 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32> | ||
"func.return"(%4) : (tensor<1x3x10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_batch_normalization | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_1_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.batch_norm_inference"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, feature_index = 1 : i64} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32> | ||
// CHECK-NEXT: return [[VAR_4_]] : tensor<1x3x10x10xf32> | ||
// CHECK-NEXT: } | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s | ||
|
||
// Test when output shape is unknown | ||
func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { | ||
%0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> | ||
"func.return"(%0) : (tensor<5x5x?x32xf32>) -> () | ||
// CHECK-LABEL: func @test_concat_dynamic_shape | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
// ----- | ||
|
||
// Test when axis is negative | ||
func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { | ||
%0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> | ||
"func.return"(%0) : (tensor<5x5x4x32xf32>) -> () | ||
// CHECK-LABEL: func @test_concat_negative_axis | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x4x32xf32> | ||
// CHECK-NEXT: } | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's an error. Suggest using llvm::dbgs() and DEBUG_TYPE here.