-
Notifications
You must be signed in to change notification settings - Fork 325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Mhlo] BatchNorm & Concat #1539
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
94d96cb
[Mhlo] BatchNorm & Concat
chenchongsong 2cedfc1
[Mhlo] use 'result' instead of 'result->getResults()'
chenchongsong e3ed189
[Mhlo] fixes for code review
chenchongsong 1b74bda
Merge branch 'main' into upstream-main
tungld 62e8096
Fixes CMakeList & logging for debug
chenchongsong 117929b
Merge branch 'main' into upstream-main
AlexandreEichenberger 5afc553
Merge branch 'main' into upstream-main
tungld File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===----------- Normalization.cpp - Lowering Normalization Ops -----------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX Normalization Operators to Mhlo dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
struct ONNXBatchNormalizationInferenceModeOpLoweringToMhlo | ||
: public ConversionPattern { | ||
ONNXBatchNormalizationInferenceModeOpLoweringToMhlo(MLIRContext *ctx) | ||
: ConversionPattern( | ||
mlir::ONNXBatchNormalizationInferenceModeOp::getOperationName(), 1, | ||
ctx) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
// batchnorm{epsilon}(x, scale, bias, mean, variance) = | ||
// scale * (x - mean) / sqrt(variance + epsilon) + bias | ||
ONNXBatchNormalizationInferenceModeOpAdaptor operandAdaptor( | ||
operands, op->getAttrDictionary()); | ||
Location loc = op->getLoc(); | ||
|
||
Value operand = operandAdaptor.X(); | ||
Value scale = operandAdaptor.scale(); | ||
Value bias = operandAdaptor.B(); | ||
Value mean = operandAdaptor.mean(); | ||
Value variance = operandAdaptor.var(); | ||
llvm::APFloat eps = operandAdaptor.epsilon(); | ||
|
||
Value result = rewriter.create<mhlo::BatchNormInferenceOp>(loc, | ||
op->getResultTypes(), operand, scale, bias, mean, variance, eps, 1); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXNormalizationOpToMhloPattern( | ||
RewritePatternSet &patterns, MLIRContext *ctx) { | ||
patterns.insert<ONNXBatchNormalizationInferenceModeOpLoweringToMhlo>(ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===---------------- Concat.cpp - Lowering Concat Op -------------------===// | ||
// | ||
// Copyright 2019-2022 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers the ONNX Concat Operator to Mhlo dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "llvm/Support/Debug.h" | ||
|
||
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" | ||
#include "src/Support/TypeUtilities.hpp" | ||
|
||
#define DEBUG_TYPE "onnx_to_mhlo" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
struct ONNXConcatOpLoweringToMhlo : public ConversionPattern { | ||
ONNXConcatOpLoweringToMhlo(MLIRContext *ctx) | ||
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
|
||
Location loc = op->getLoc(); | ||
ONNXConcatOpAdaptor operandAdaptor(operands); | ||
ONNXConcatOp concatOp = llvm::cast<ONNXConcatOp>(op); | ||
|
||
assert(op->getNumResults() == 1 && "ONNXConcatOp shoule have 1 result"); | ||
Type resultType = op->getResult(0).getType(); | ||
if (!onnx_mlir::isRankedShapedType(resultType)) { | ||
LLVM_DEBUG(llvm::dbgs() << "Concat Output Is Not Ranked\n"); | ||
return failure(); | ||
} | ||
int64_t rank = onnx_mlir::getRank(resultType); | ||
int64_t axis = concatOp.axis(); | ||
axis = axis >= 0 ? axis : rank + axis; | ||
assert(axis >= -rank && axis <= rank - 1 && "Axis out of rank range"); | ||
|
||
ValueRange inputs = operandAdaptor.inputs(); | ||
Value result = rewriter.create<mhlo::ConcatenateOp>( | ||
loc, op->getResultTypes(), inputs, rewriter.getI64IntegerAttr(axis)); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXConcatOpToMhloPattern( | ||
RewritePatternSet &patterns, MLIRContext *ctx) { | ||
patterns.insert<ONNXConcatOpLoweringToMhlo>(ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s | ||
|
||
func @test_batch_normalization(%arg0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> { | ||
%0 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%1 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%2 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%3 = "onnx.Constant"() {value = dense<1.0> : tensor<3xf32>} : () -> tensor<3xf32> | ||
%4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32> | ||
"func.return"(%4) : (tensor<1x3x10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_batch_normalization | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_1_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_2_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_3_:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> | ||
// CHECK-NEXT: [[VAR_4_:%.+]] = "mhlo.batch_norm_inference"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, feature_index = 1 : i64} : (tensor<1x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x3x10x10xf32> | ||
// CHECK-NEXT: return [[VAR_4_]] : tensor<1x3x10x10xf32> | ||
// CHECK-NEXT: } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s | ||
|
||
// Test when output shape is unkown | ||
func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { | ||
%0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> | ||
"func.return"(%0) : (tensor<5x5x?x32xf32>) -> () | ||
// CHECK-LABEL: func @test_concat_dynamic_shape | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x?x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x?x32xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
// ----- | ||
|
||
// Test when axis is negative | ||
func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { | ||
%0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> | ||
"func.return"(%0) : (tensor<5x5x4x32xf32>) -> () | ||
// CHECK-LABEL: func @test_concat_negative_axis | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x1x32xf32>, [[PARAM_1_:%.+]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<5x5x4x32xf32> | ||
// CHECK-NEXT: } | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not specific to this PR, I am a bit concerned about failures during rewrite patterns that can cause the whole compilation pipeline terminated. This would work if all ONNX ops can be lowered to MHLO. However, if an op does not satisfy the condition here, there is no a fallback path so that the op will be lowered to other dialects such as Krnl for CPU or ZHigh for NNPA.
One way to solve this is to use
DynamicallyLegalOp
in--convert--onnx-to-mhlo
pass to selectively add operations into the conversion target. We did this for--convert-onnx-to-zhigh
: https://github.com/onnx/onnx-mlir/blob/main/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp#L323FYI @AlexandreEichenberger
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment! Will discuss with co-workers about the fallback path for mhlo lowering. @yaochengji
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tungld Currently
-convert-onnx-to-hlo
usesapplyPartialConversion
and it doesn't make ONNX Dialect or any ONNX operations illegal. Therefore it is fine if there're some ONNX ops couldn't be converted to MHLO.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yaochengji I see. Thanks for clarifying that!