[Mhlo] BatchNorm & Concat #1539
Conversation
Can one of the admins verify this patch?
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s -split-input-file | FileCheck %s

// Test when output shape is unknown
func @test_concat_unknown_dims(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<?x?x?x?xf32> {
Why is this unknown? The shape could be inferred, couldn't it?
What you want to demo is <5x5x?x32>, <5x5x?x32> -> <5x5x?x32>?
Should we call --shape-inference before --convert-onnx-to-mhlo? It depends on whether MHLO ops require ranked tensors or not.
@jenkins-droid test this please
if (op->getNumResults() < 1) {
  op->emitError() << "ONNXConcatOp Has No Output\n";
  return failure();
}
Isn't this caught by the operation verifier?
  return failure();
}
RankedTensorType resultType =
    op->getResult(0).getType().dyn_cast_or_null<RankedTensorType>();
FYI, we have isRankedShapedType in TypeUtilities for this purpose. Please feel free to use it.
    op->getResult(0).getType().dyn_cast_or_null<RankedTensorType>();
if (resultType == nullptr) {
  op->emitError() << "Concat Output Is Not Ranked\n";
  return failure();
Not specific to this PR, but I am a bit concerned about failures during rewrite patterns that can terminate the whole compilation pipeline. This would work if all ONNX ops can be lowered to MHLO. However, if an op does not satisfy the condition here, there is no fallback path by which the op can be lowered to other dialects such as Krnl for CPU or ZHigh for NNPA.
One way to solve this is to use DynamicallyLegalOp in the --convert-onnx-to-mhlo pass to selectively add operations into the conversion target. We did this for --convert-onnx-to-zhigh: https://github.com/onnx/onnx-mlir/blob/main/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp#L323
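For illustration, a minimal sketch of what that could look like inside the MHLO pass, assuming a ConversionTarget set up in the pass's runOnOperation (the dialect registration and the use of ONNXConcatOp here are illustrative, not code from this PR):

  // Sketch only: mark ONNXConcatOp as legal (i.e., leave it untouched for a
  // later lowering) whenever its result type is not ranked, so this pass
  // skips the op instead of failing on it.
  ConversionTarget target(getContext());
  target.addLegalDialect<mhlo::MhloDialect>();
  target.addDynamicallyLegalOp<ONNXConcatOp>([](ONNXConcatOp op) {
    return !op.getResult().getType().isa<RankedTensorType>();
  });

With this, ranked-result concats must be converted, while unranked ones remain ONNX ops for a fallback pipeline.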
Thanks for the comment! I will discuss the fallback path for MHLO lowering with my co-workers. @yaochengji
@tungld Currently --convert-onnx-to-mhlo uses applyPartialConversion, and it doesn't make the ONNX dialect or any ONNX operations illegal. Therefore it is fine if some ONNX ops cannot be converted to MHLO.
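For context, applyPartialConversion is typically invoked roughly like this (a sketch; the `target` and `patterns` variables are assumed to be populated elsewhere in the pass):

  // With applyPartialConversion, ops that no pattern converts and that are
  // not explicitly marked illegal are simply left in place, so unsupported
  // ONNX ops survive the pass instead of aborting compilation.
  if (failed(applyPartialConversion(
          getOperation(), target, std::move(patterns))))
    signalPassFailure();

applyFullConversion, by contrast, would fail the pass if any illegal op remains.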
@yaochengji I see. Thanks for clarifying that!
Signed-off-by: chongsong.chen <[email protected]>
Force-pushed from 322c4ee to e3ed189
Can one of the admins verify this patch?
assert(op->getNumResults() == 1 && "ONNXConcatOp should have 1 result");
Type resultType = op->getResult(0).getType();
if (!onnx_mlir::isRankedShapedType(resultType)) {
  op->emitError() << "Concat Output Is Not Ranked\n";
I don't think it's an error. Suggest using llvm::dbgs() and DEBUG_TYPE here.
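For reference, the usual LLVM idiom being suggested looks roughly like this (the DEBUG_TYPE string is illustrative):

  #include "llvm/Support/Debug.h"

  #define DEBUG_TYPE "onnx-to-mhlo"

  // Printed only in debug builds when run with -debug-only=onnx-to-mhlo;
  // unlike emitError, this does not turn the situation into a hard failure.
  LLVM_DEBUG(llvm::dbgs() << "Concat output is not ranked\n");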
Signed-off-by: chongsong.chen <[email protected]>
Can one of the admins verify this patch?
@chenchongsong merged with the latest before starting to make a run, hope it's ok with you.
LGTM. Thanks!
Jenkins Linux amd64 Build #6671 [push] [Mhlo] BatchNorm & Conca... started at 03:42
Jenkins Linux ppc64le Build #5783 [push] [Mhlo] BatchNorm & Conca... started at 04:45
Jenkins Linux s390x Build #6687 [push] [Mhlo] BatchNorm & Conca... started at 04:42
Jenkins Linux s390x Build #6687 [push] [Mhlo] BatchNorm & Conca... passed after 1 hr 43 min
Jenkins Linux ppc64le Build #5783 [push] [Mhlo] BatchNorm & Conca... passed after 1 hr 45 min
Jenkins Linux amd64 Build #6671 [push] [Mhlo] BatchNorm & Conca... passed after 1 hr 59 min
This PR serves the following purpose: lowering ONNXBatchNormalizationInferenceModeOp and ONNXConcatOp from ONNX-MLIR to MHLO.