
fuse quantize+transpose and transpose+dequantize #49509

Merged 20 commits into PaddlePaddle:develop on Feb 8, 2023
Conversation

@paulinagacek (Contributor) commented Jan 3, 2023

PR types

Others

PR changes

Others

Describe

The U2++ model often contains chains of quantize followed by transpose, and transpose followed by dequantize. Since quantize, dequantize, and transpose are all implemented as reorders, transpose can be extended to perform the quantize or dequantize operation as well.

This PR adds a fuse pass that performs the following fusions:

  • fp32 -> quantize -> transpose2 becomes fp32 -> transpose2 (with int8 output)
  • int8 -> transpose2 -> dequantize becomes int8 -> transpose2 (with fp32 output)
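Conceptually, the pass rewrites the operator chain so that a single transpose2 performs both the reorder and the (de)quantization. A minimal Python sketch of the rewrite over a flat op list (a hypothetical simplified representation; the real pass matches patterns on Paddle's graph IR):

```python
def fuse_quant_transpose(ops):
    """Collapse quantize->transpose2 and transpose2->dequantize pairs
    into a single transpose2 that carries the (de)quantization."""
    fused, i = [], 0
    while i < len(ops):
        a = ops[i]
        b = ops[i + 1] if i + 1 < len(ops) else None
        if a == "quantize" and b == "transpose2":
            fused.append("transpose2(int8 out)")  # transpose also quantizes
            i += 2
        elif a == "transpose2" and b == "dequantize":
            fused.append("transpose2(fp32 out)")  # transpose also dequantizes
            i += 2
        else:
            fused.append(a)
            i += 1
    return fused
```

For example, a chain `quantize -> transpose2 -> matmul -> transpose2 -> dequantize` collapses to `transpose2 -> matmul -> transpose2`, saving two reorder-like ops per occurrence.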

Performance improvement in the BERT model, average inference time per batch:

  • without fuse - 28.24 ms
  • with fuse - 28.16 ms (0.08 ms faster)

@paddle-bot commented Jan 3, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paulinagacek paulinagacek marked this pull request as ready for review January 16, 2023 15:37
@paulinagacek (Contributor, Author):

@wozna @sfraczek could you review, please?

@jczaja jczaja self-requested a review January 30, 2023 10:11
@sfraczek (Contributor) left a comment:

Great work 👍. I just have a few minor suggestions.

@@ -0,0 +1,37 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
@sfraczek (Contributor): 2023 (update the copyright year)

@@ -0,0 +1,173 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
@sfraczek (Contributor): 2023 (update the copyright year)

@@ -0,0 +1,121 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
@sfraczek (Contributor): 2023 (update the copyright year)

Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("quant_transpose2_dequant_onednn_fuse_pass", graph);
@sfraczek (Contributor), Jan 30, 2023:

I would suggest putting this string in a private variable named name_scope. You can put it in the header, in the class declaration, so it's accessible in both functions.

Comment on lines 43 to 46
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
(transpose2_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool,
transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
@sfraczek (Contributor):

The second condition is evaluated only when the first condition evaluates to false, so you can safely remove the second check for the existence of the use_mkldnn attribute and just get the value.
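The short-circuit argument can be sketched in Python, whose `or` behaves the same way as C++'s `||` (hypothetical attrs dict standing in for the op's attribute map):

```python
def should_skip(attrs):
    """Skip ops that don't have use_mkldnn set to True.

    The original C++ form was: not has_attr or (has_attr and not value).
    Because `or` short-circuits, the second membership test is redundant:
    whenever the right-hand side runs, the key is already known to exist.
    """
    return "use_mkldnn" not in attrs or not attrs["use_mkldnn"]
```

The simplified form is equivalent for every input, including the case where the attribute is missing.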

->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
->AsOutput()
@sfraczek (Contributor):

I think this should be AsIntermediate(), because it's removed in the pass.

auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
auto dequant_in = pattern->NewNode(dequant_in_repr())
->assert_is_op_input("dequantize", "Input");
@sfraczek (Contributor):

Here it's also removed, so I think it should be AsIntermediate().

inputs={'Input': ['transpose2_output_2']},
outputs={'Output': ['dequantize_output']},
attrs={
'is_negative_input': False,
@sfraczek (Contributor):

I don't know if dequantize uses this attribute - does it?

@paulinagacek (Contributor, Author):

You're right, it does not.

@sfraczek (Contributor):

Anyway, as we discussed, let's leave it, because some other places in the code do read or write this attribute.

inputs={'Input': ['input_data']},
outputs={'Output': ['quantize_output']},
attrs={
'is_negative_input': False,
@sfraczek (Contributor):

Could/should this also be sampled from True or False? If yes, please make sure the generate_input function provides data in accordance with the value.

Comment on lines +107 to +117
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=['quant_transpose2_dequant_onednn_fuse_pass'],
)
yield config, ['transpose2', 'transpose2'], (1e-5, 1e-5)

def test(self):
self.run_and_statis(
quant=False, passes=['quant_transpose2_dequant_onednn_fuse_pass']
)
@sfraczek (Contributor):

I don't see passes being defined in sample_predictor_configs in other tests. Maybe it is not required there?

@paulinagacek (Contributor, Author):

Without it, the test does not see the pass.

@sfraczek (Contributor):

OK.

@sfraczek (Contributor) left a comment:

I found a few more.


int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
@sfraczek (Contributor):

@Silv3S is our expert on this: he noticed the lack of the check below, and the need to add some calls to AddOpCompat in the class constructor, like in other passes.

Suggested change:

    Graph *g) {
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "Pass in op compat failed.";
        return;
      }

@paulinagacek (Contributor, Author):

With this check I get an error: ValueError: (InvalidArgument) At least one OpCompat instance should be added.

REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().GE(
"transpose2", 0));
@sfraczek (Contributor):

@Silv3S should this be .EQ("transpose2", 0);? There is only version 0, and we don't know what will come in the future.
Also, there should be a similar line for quantize, but .LE("quantize", 2), because the pass works regardless of whether the shift attribute or bfloat is present (see the bottom of quantize_op.cc).

(Member):

I think it's more elegant to add .EQ rather than .GE, because we can't be sure that a newer version will be compatible with this pass.

Comment on lines +119 to +120
if (output_data_type == "bf16") {
out_dtype = DataType::BFLOAT16;
@sfraczek (Contributor):

I think we have no test here for the case where a fuse has been performed and the fuse has set bf16.

@paulinagacek (Contributor, Author), Feb 1, 2023:

@sfraczek So what would you suggest?

@sfraczek (Contributor):

Either add a unit test, or add an enforce that it's not bf16 until this feature is validated for bf16.

Comment on lines +69 to +72
bool is_bfloat =
quant_op->Op()->HasAttr("bfloat16")
? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16"))
: false;
@sfraczek (Contributor):

I wonder if we have to deal with the output_format attribute from quantize_op.cc?

@paulinagacek (Contributor, Author):

This attribute is only set but never used, so I skipped it.

->AsInput()
->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
@sfraczek (Contributor):

I think you should check that quant_out has only one output, because we want to make sure there is only one connection between quantize and transpose2, so we can safely remove quantize.

}

PDNode *patterns::Transpose2Dequant::operator()() {
auto *transpose2_op =
@sfraczek (Contributor):

Same here - check that dequant_in has only one output.

@paulinagacek (Contributor, Author):

Shouldn't we check if dequant_in has only one input?

pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto dequant_op = pattern->NewNode(dequant_op_repr())
->assert_is_op("dequantize")
->AsIntermediate();
@sfraczek (Contributor):

I think AsIntermediate has no effect on ops, only vars. I meant that you should add this to dequant_in above, if it's removed.

auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_op = pattern->NewNode(quant_op_repr())
->assert_is_op("quantize")
->AsIntermediate();
@sfraczek (Contributor), Feb 2, 2023:

Same here - ops don't seem to have roles; I think roles only concern links. However, I don't know everything the code in this file is doing, so I may be wrong, but that's what we've been doing :).

@sfraczek (Contributor):

If you find out whether we should add AsIntermediate to removed nodes, that would be great, and we could fix all patterns.

@paulinagacek (Contributor, Author):

@sfraczek From what I've found in the code, AsIntermediate is used only during subgraph validation - adding AsIntermediate prevents a node from being used in multiple overlapping patterns or linked with external nodes.

@sfraczek (Contributor), Feb 3, 2023:

Can you tell if that is only for vars, or for ops too?

Comment on lines 1006 to 1007
->AsInput()
->AsIntermediate()
@sfraczek (Contributor):

AsIntermediate() overwrites the result of AsInput().
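This last-call-wins behavior can be illustrated with a minimal builder sketch (a hypothetical Node class, not Paddle's actual PDNode):

```python
class Node:
    """Toy stand-in for a pattern node whose role is a single field."""

    def __init__(self):
        self.role = None

    def AsInput(self):
        self.role = "input"
        return self  # chainable, like the real builder API

    def AsIntermediate(self):
        self.role = "intermediate"  # silently overwrites any earlier role
        return self
```

Chaining `AsInput().AsIntermediate()` leaves only the last role set, which is why the earlier `AsInput()` call in the diff has no lasting effect.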

@paulinagacek (Contributor, Author):

You're right, thanks!

@onecatcn onecatcn assigned weishengying and unassigned yeliang2258 Feb 7, 2023
@paulinagacek paulinagacek requested review from sfraczek and removed request for jczaja February 7, 2023 08:39
@sfraczek (Contributor) left a comment:

LGTM :)

@wozna (Contributor) left a comment:

LGTM

@jczaja jczaja merged commit 197a4ff into PaddlePaddle:develop Feb 8, 2023
@paddle-bot commented Feb 8, 2023

Your PR has been merged into the repository. An official integration test will be conducted later. Stay tuned.

Labels: contributor, External developers, Intel
7 participants