fuse quantize+transpose and transpose+dequantize #49509

File: paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -979,6 +979,44 @@ PDNode *patterns::OperatorActivation::operator()(
  return activation_out;
}

PDNode *patterns::QuantTranspose2::operator()() {
  auto *quant_in = pattern->NewNode(quant_in_repr())
                       ->AsInput()
                       ->assert_is_op_input("quantize", "Input");
  auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
  auto *quant_out = pattern->NewNode(quant_out_repr())
                        ->AsOutput()
                        ->AsIntermediate()
                        ->assert_has_n_outputs(1)
                        ->assert_is_op_output("quantize")
                        ->assert_is_op_input("transpose2", "X");
  auto *transpose2_op =
      pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");

  quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
  transpose2_op->LinksFrom({quant_out});

  return transpose2_op;
}

Review comment (on quant_out): I think this should be …
Review comment (on quant_out): I think you should check if quant_out has only one output, because we want to make sure that there is only one connection between quantize and transpose so we can safely remove quantize.
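
A sketch of the rewrite this pattern feeds, based on my reading of the fusion handler in the new pass file below (the diagrams are mine, not part of the diff):

    before: quant_in -> quantize -> quant_out -> transpose2 -> ...
    after:  quant_in -> transpose2 (scale, shift, output_data_type set) -> ...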
|
||
PDNode *patterns::Transpose2Dequant::operator()() { | ||
auto *transpose2_op = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, check if dequant_in has only one output. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't we check if dequant_in has only one input? |
||
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); | ||
auto dequant_in = pattern->NewNode(dequant_in_repr()) | ||
->AsIntermediate() | ||
->assert_has_n_inputs(1) | ||
->assert_is_op_input("dequantize", "Input"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here it's also removed so I think it should be |
||
auto dequant_op = | ||
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); | ||
auto dequant_out = pattern->NewNode(dequant_out_repr()) | ||
->AsOutput() | ||
->assert_is_op_output("dequantize", "Output"); | ||
|
||
transpose2_op->LinksTo({dequant_in}); | ||
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); | ||
return dequant_out; | ||
} | ||
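
And the corresponding sketch for this pattern, again my reading of the handler below:

    before: ... -> transpose2 -> dequant_in -> dequantize -> dequant_out
    after:  ... -> transpose2 (scale = 1/Scale, shift, output_data_type = "fp32") -> dequant_out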

PDNode *patterns::Squeeze2Transpose2::operator()() {
  auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
                             ->AsInput()

New file: paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc
@@ -0,0 +1,194 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
    Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope, graph);

  GraphPatternDetector gpd;
  patterns::QuantTranspose2 quant_transpose2_pattern(gpd.mutable_pattern(),
                                                     name_scope);
  quant_transpose2_pattern();

  int found_patterns_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {

Review comment: @Silv3S is our expert on this: he noticed the lack of the check below, and the need to add some AddOpCompat calls in the class constructor, as in other passes.
Suggested change: …
Reply: With this check I have an error: ValueError: (InvalidArgument) At least one OpCompat instance should be added.
Reply: Have you added something like this? https://github.com/PaddlePaddle/Paddle/pull/43519/files#diff-56e4db16c655c73ec8154c0e6530c3a5bde9d180170e4308efd1ff827771520eR97
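
The replies above reference registering an OpCompat instance; a minimal sketch of that fix (an abbreviated form of the constructor this diff adds near the bottom of the file) is:

FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
  // Registering at least one OpCompat gives IsCompat() something to validate,
  // which resolves the "At least one OpCompat instance should be added" error.
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsType<std::vector<int>>()
      .End();
}

The full version below also declares the optional XShape output.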

    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_transpose2_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_op, transpose2_op, quant_transpose2_pattern);

    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
        !(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
      VLOG(4)
          << "Only oneDNN version of transpose2 can be fused with quantize.";
      return;
    }

    // Fold the quantize op's Scale/Shift into the fused transpose2.
    float scale =
        quant_op->Op()->HasAttr("Scale")
            ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"))
            : 1;
    float shift =
        quant_op->Op()->HasAttr("Shift")
            ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift"))
            : 0;

    transpose2_op->Op()->SetAttr("scale", scale);
    transpose2_op->Op()->SetAttr("shift", shift);

    bool is_negative_output =
        quant_op->Op()->HasAttr("is_negative_input")
            ? PADDLE_GET_CONST(bool,
                               quant_op->Op()->GetAttr("is_negative_input"))
            : false;
    bool is_bfloat =
        quant_op->Op()->HasAttr("bfloat16")
            ? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16"))
            : false;

Review comment (on lines +71 to +74): I wonder if we don't have to deal with an attribute …
Reply: This attribute is only set but never used, so I skipped it.

    std::string output_dtype;
    if (is_bfloat) {
      output_dtype = "bf16";
    } else if (is_negative_output) {
      output_dtype = "int8";
    } else {
      output_dtype = "uint8";
    }
    transpose2_op->Op()->SetAttr("output_data_type", output_dtype);
    transpose2_op->Op()->SetInput("X",
                                  std::vector<std::string>({quant_in->Name()}));

    // Rewire the graph: connect quant_in directly to transpose2, then drop
    // the now-redundant quantize op and its output node.
    IR_NODE_LINK_TO(quant_in, transpose2_op);
    GraphSafeRemoveNodes(graph, {quant_op, quant_out});
    found_patterns_count++;
  };
  gpd(graph, handler);
  AddStatis(found_patterns_count);
  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
    paddle::string::PrettyLogDetail("--- fused %d quant with transpose2",
                                    found_patterns_count);
  }
}

void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
    Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope, graph);

  GraphPatternDetector gpd;
  patterns::Transpose2Dequant transpose2_dequant_pattern(gpd.mutable_pattern(),
                                                         name_scope);
  transpose2_dequant_pattern();

  int found_patterns_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(
        transpose2_op, transpose2_op, transpose2_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        dequant_in, dequant_in, transpose2_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        dequant_op, dequant_op, transpose2_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        dequant_out, dequant_out, transpose2_dequant_pattern);

    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
        !(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
      VLOG(4)
          << "Only oneDNN version of transpose2 can be fused with dequantize.";
      return;
    }

    // Dequantize divides by Scale, so the fused reorder multiplies by 1/Scale.
    float scale =
        dequant_op->Op()->HasAttr("Scale")
            ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"))
            : 1;
    float reorder_scale = 1.0 / scale;
    float shift =
        dequant_op->Op()->HasAttr("Shift")
            ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift"))
            : 0;

    transpose2_op->Op()->SetAttr("scale", reorder_scale);
    transpose2_op->Op()->SetAttr("shift", shift);
    transpose2_op->Op()->SetAttr("output_data_type", std::string("fp32"));
    transpose2_op->Op()->SetOutput(
        "Out", std::vector<std::string>({dequant_out->Name()}));

    // Rewire the graph: transpose2 now writes straight to dequant_out, so the
    // dequantize op and its former input node can be removed.
    IR_NODE_LINK_TO(transpose2_op, dequant_out);
    GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
    found_patterns_count++;
  };

  gpd(graph, handler);
  AddStatis(found_patterns_count);
  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
    paddle::string::PrettyLogDetail("--- fused %d transpose2 with dequant",
                                    found_patterns_count);
  }
}
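
Why the two directions scale differently: assuming the usual semantics of Paddle's oneDNN quantize/dequantize ops (my assumption; the formulas are not stated in this diff),

    quantize:   q = round(Scale * x + Shift)
    dequantize: x = (q - Shift) / Scale

so folding quantize into transpose2 keeps scale = Scale, while folding dequantize in uses the reciprocal, reorder_scale = 1 / Scale, as in the code above.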

void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const {
  FuseQuantizeTranspose2(graph);
  FuseTranspose2Dequantize(graph);
}

FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsType<std::vector<int>>()
      .End();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(quant_transpose2_dequant_onednn_fuse_pass,
              paddle::framework::ir::FuseQuantTranspose2DequantOneDNNPass);
REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "transpose2", 0));

Review comment (on the capability registration): @Silv3S should this be …
Reply: I think it's more elegant to add .EQ rather than .GE, because we can't be sure that a newer version will be compatible with this pass.
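
For contrast, the .GE variant the thread appears to weigh (an assumption on my part; the question above is truncated) would accept any newer transpose2 op version instead of pinning version 0:

REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().GE(
            "transpose2", 0));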

New file: paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h
@@ -0,0 +1,41 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase {
 public:
  virtual ~FuseQuantTranspose2DequantOneDNNPass() {}
  FuseQuantTranspose2DequantOneDNNPass();

 protected:
  void ApplyImpl(Graph *graph) const override;
  void FuseQuantizeTranspose2(Graph *graph) const;
  void FuseTranspose2Dequantize(Graph *graph) const;

 private:
  std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass";
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
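
A usage sketch, not part of this PR: one way the registered pass might be enabled through the C++ inference API. The model path is a placeholder, the header path depends on the inference library layout, and the pass name matches the REGISTER_PASS call above.

#include "paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("path/to/model_dir");  // placeholder model directory
  config.EnableMKLDNN();  // the pass only fires on transpose2 with use_mkldnn
  config.pass_builder()->AppendPass("quant_transpose2_dequant_onednn_fuse_pass");
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}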