fuse quantize+transpose and transpose+dequantize #49509
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
Great work 👍. I just have a few minor suggestions.
@@ -0,0 +1,37 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2023
@@ -0,0 +1,173 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2023
@@ -0,0 +1,121 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2023
Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init("quant_transpose2_dequant_onednn_fuse_pass", graph);
I would suggest putting this string in a private variable named name_scope. You can put it in the header, in the class declaration, so it's accessible in both functions.
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
    (transpose2_op->Op()->HasAttr("use_mkldnn") &&
     !(PADDLE_GET_CONST(bool,
                        transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
The second condition is evaluated only when the first condition evaluates to false, so you can safely remove the second check for the existence of the use_mkldnn attribute and just get the value.
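The short-circuit argument can be illustrated with a plain dict standing in for the op's attribute map (a hypothetical stand-in, not the Paddle API):

```python
def use_mkldnn_disabled(attrs):
    # Original shape of the condition:
    #   not has_attr or (has_attr and not value)
    # The right operand of `or` only runs when the left is False,
    # i.e. when the attribute exists, so the inner existence check
    # is redundant and can be dropped:
    return "use_mkldnn" not in attrs or not attrs["use_mkldnn"]
```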
->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
                      ->AsOutput()
I think this should be AsIntermediate(), because it's removed in the pass.
auto *transpose2_op =
    pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
auto dequant_in = pattern->NewNode(dequant_in_repr())
                      ->assert_is_op_input("dequantize", "Input");
Here it's also removed, so I think it should be AsIntermediate().
inputs={'Input': ['transpose2_output_2']},
outputs={'Output': ['dequantize_output']},
attrs={
    'is_negative_input': False,
I don't know if dequantize uses this attribute?
You're right, it does not
Anyway, as we discussed, let's leave it, because some other places in the code do read or write this attribute.
inputs={'Input': ['input_data']},
outputs={'Output': ['quantize_output']},
attrs={
    'is_negative_input': False,
Could/should this also be sampled from True or False? If yes, please make sure the generate_input function is providing data in accordance with the value.
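A sketch of how generate_input could honor a sampled is_negative_input; the shape, ranges, and default here are illustrative assumptions, not the actual test code:

```python
import numpy as np

def generate_input(is_negative_input, shape=(1, 3, 4, 4)):
    if is_negative_input:
        # signed input: values may be negative
        return np.random.uniform(-1.0, 1.0, shape).astype(np.float32)
    # unsigned input: keep all values non-negative
    return np.random.uniform(0.0, 1.0, shape).astype(np.float32)
```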
def sample_predictor_configs(self, program_config):
    config = self.create_inference_config(
        use_mkldnn=True,
        passes=['quant_transpose2_dequant_onednn_fuse_pass'],
    )
    yield config, ['transpose2', 'transpose2'], (1e-5, 1e-5)
def test(self):
    self.run_and_statis(
        quant=False, passes=['quant_transpose2_dequant_onednn_fuse_pass']
    )
I don't see passes being defined in sample_predictor_configs in other tests. Maybe it is not required there?
Without it, the test does not see the pass.
ok
I found a few more.
int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                   Graph *g) {
@Silv3S is our expert on this: he noticed the lack of the check below and the need to add to the class constructor some calls to AddOpCompat, like in other passes.
Graph *g) {
  if (!IsCompat(subgraph, g)) {
    LOG(WARNING) << "Pass in op compat failed.";
    return;
  }
With this check I get an error: ValueError: (InvalidArgument) At least one OpCompat instance should be added
Have you added something like this? https://github.com/PaddlePaddle/Paddle/pull/43519/files#diff-56e4db16c655c73ec8154c0e6530c3a5bde9d180170e4308efd1ff827771520eR97
REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().GE(
            "transpose2", 0));
@Silv3S, should this be .EQ("transpose2", 0)? Because there is only version 0 and we don't know what will come in the future.
Also, there should be a similar line for the quantize op, but .LE("quantize", 2), because it works regardless of whether the shift attr or bfloat is present (see the bottom of quantize_op.cc).
I think it's more elegant to add .EQ rather than .GE, because we can't be sure that a newer version will be compatible with this pass.
if (output_data_type == "bf16") {
  out_dtype = DataType::BFLOAT16;
I think we have no test here for the case where we have performed the fuse and the fuse has set bf16.
@sfraczek So what would you suggest?
Either add a unit test, or add an enforce that it's not bf16 until this feature is validated for bf16.
bool is_bfloat =
    quant_op->Op()->HasAttr("bfloat16")
        ? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16"))
        : false;
I wonder if we don't have to deal with the output_format attribute from quantize_op.cc?
This attribute is only set but never used, so I skipped it
->AsInput()
    ->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
I think you should check if quant_out has only one output, because we want to make sure that there is only one connection between quantize and transpose so we can safely remove quantize.
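The safety argument can be modeled with a toy consumer list (plain Python here, not the GraphPatternDetector API): the quantize node can only be erased when transpose2 is the sole consumer of its output, otherwise removing it would break the other consumers.

```python
def safe_to_fuse(consumers_of_quant_out):
    # The quantize op may be removed only if its output var feeds
    # exactly one consumer (the transpose2 being fused).
    return len(consumers_of_quant_out) == 1
```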
}

PDNode *patterns::Transpose2Dequant::operator()() {
  auto *transpose2_op =
Same here, check if dequant_in has only one output.
Shouldn't we check if dequant_in has only one input?
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto dequant_op = pattern->NewNode(dequant_op_repr())
                      ->assert_is_op("dequantize")
                      ->AsIntermediate();
I think AsIntermediate has no effect on ops, only on vars. I meant that you add this to the dequant_in above if it's removed.
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_op = pattern->NewNode(quant_op_repr())
                     ->assert_is_op("quantize")
                     ->AsIntermediate();
If you found out that we should add AsIntermediate to removed nodes, it would be great and we could fix all patterns.
@sfraczek From what I've found in the code, AsIntermediate is used only during subgraph validation: adding AsIntermediate prevents a node from being used in multiple overlapping patterns or linked with external nodes.
Can you tell if it is only for vars, or for ops too?
->AsInput()
->AsIntermediate()
AsIntermediate() overwrites the result of AsInput()
You're right, thanks!
LGTM :)
LGTM
Your PR has been merged into the Paddle repository. Please keep an eye on the subsequent test results.
PR types
Others
PR changes
Others
Describe
The U2++ model often has chains of quantize followed by transpose, and transpose followed by dequantize, operators.
As quantize, dequantize, and transpose are all reorders, we can extend transpose to also perform the quantize or dequantize operation.
This PR adds a fuse pass which performs the fuses below:
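The reorder argument can be checked numerically: quantize is elementwise and transpose only permutes data, so the two commute. The scale (127.0) and rounding mode below are illustrative assumptions, not the op's exact semantics:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(2, 3, 4)).astype(np.float32)
scale = 127.0

# quantize then transpose
q_then_t = np.round(x * scale).astype(np.int8).transpose(0, 2, 1)
# transpose then quantize
t_then_q = np.round(x.transpose(0, 2, 1) * scale).astype(np.int8)
```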
Performance improvement (faster by 0.08s) in the BERT model's avg. inference time per batch: