fuse quantize+transpose and transpose+dequantize #49509

Merged (20 commits) on Feb 8, 2023

Commits
8232bbe
QuantTranpose pattern is being found by pass
paulinagacek Dec 14, 2022
dd593de
quant + transpose fuse
paulinagacek Jan 10, 2023
f9b3813
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 10, 2023
ff9d0ba
code style changes
paulinagacek Jan 12, 2023
b54b5f6
UT written, reorder fixed
paulinagacek Jan 13, 2023
db9f824
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 13, 2023
6dfdf2d
Dequantize + transpose2 fuse added
paulinagacek Jan 16, 2023
5ad2b62
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 16, 2023
8ed8af0
pass name changed
paulinagacek Jan 16, 2023
340634b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 16, 2023
b6d7048
UT added & shift corrected
paulinagacek Jan 16, 2023
0192d21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 16, 2023
ac0667d
got rid of redundancy
paulinagacek Jan 16, 2023
bbe8fad
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Jan 16, 2023
b25be03
review changes
paulinagacek Feb 1, 2023
7e4018f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Feb 1, 2023
f005c98
AsIntermediate corrected
paulinagacek Feb 3, 2023
fb3aa34
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Feb 3, 2023
cd54e09
compat added
paulinagacek Feb 3, 2023
92e714b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Feb 3, 2023
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -177,6 +177,7 @@ if(WITH_MKLDNN)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(layer_norm_onednn_optimization_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(quant_transpose2_dequant_onednn_fuse_pass inference DIR mkldnn)
pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
38 changes: 38 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -979,6 +979,44 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out;
}

PDNode *patterns::QuantTranspose2::operator()() {
auto *quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
Review comment (Contributor): I think you should check if quant_out has only one output, because we want to make sure that there is only one connection between quantize and transpose so we can safely remove quantize.

->AsOutput()
Review comment (Contributor): I think this should be AsIntermediate() because it is removed in the pass.

->AsIntermediate()
->assert_has_n_outputs(1)
->assert_is_op_output("quantize")
->assert_is_op_input("transpose2", "X");
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");

quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
transpose2_op->LinksFrom({quant_out});

return transpose2_op;
}
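// Illustrative note (not part of the original diff): the subgraph matched
// above is
//   quant_in -> quantize -> quant_out -> transpose2
// with quant_out restricted to a single consumer, so the quantize op can be
// safely removed by the fuse pass.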

PDNode *patterns::Transpose2Dequant::operator()() {
auto *transpose2_op =
Review comment (Contributor): Same here, check if dequant_in has only one output.

Reply (Contributor Author): Shouldn't we check if dequant_in has only one input?

pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
auto dequant_in = pattern->NewNode(dequant_in_repr())
->AsIntermediate()
->assert_has_n_inputs(1)
->assert_is_op_input("dequantize", "Input");
Review comment (Contributor): Here it's also removed, so I think it should be AsIntermediate().

auto dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");

transpose2_op->LinksTo({dequant_in});
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
return dequant_out;
}
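// Illustrative note (not part of the original diff): the subgraph matched
// above is
//   transpose2 -> dequant_in -> dequantize -> dequant_out
// with dequant_in restricted to a single producer (the transpose2 output).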

PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -552,6 +552,29 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};

struct QuantTranspose2 : public PatternBase {
QuantTranspose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_transpose2") {}

PDNode* operator()();

PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
PATTERN_DECL_NODE(transpose2_op);
};

struct Transpose2Dequant : public PatternBase {
Transpose2Dequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose2_dequant") {}
PDNode* operator()();

PATTERN_DECL_NODE(transpose2_op);
PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};

struct Squeeze2Transpose2 : public PatternBase {
Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
194 changes: 194 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc
@@ -0,0 +1,194 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::QuantTranspose2 quant_transpose2_pattern(gpd.mutable_pattern(),
name_scope);
quant_transpose2_pattern();

int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
Review comment (Contributor): @Silv3S is our expert on this: he noticed the lack of the check below and the need to add some AddOpCompat calls to the class constructor, as in other passes.

Suggested change:
    Graph *g) {
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "Pass in op compat failed.";
        return;
      }

Reply (Contributor Author): With this check I get an error: ValueError: (InvalidArgument) At least one OpCompat instance should be added.

if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, quant_transpose2_pattern);

if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with quantize.";
return;
}

float scale =
quant_op->Op()->HasAttr("Scale")
? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"))
: 1;
float shift =
quant_op->Op()->HasAttr("Shift")
? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift"))
: 0;

transpose2_op->Op()->SetAttr("scale", scale);
transpose2_op->Op()->SetAttr("shift", shift);

bool is_negative_output =
quant_op->Op()->HasAttr("is_negative_input")
? PADDLE_GET_CONST(bool,
quant_op->Op()->GetAttr("is_negative_input"))
: false;
bool is_bfloat =
quant_op->Op()->HasAttr("bfloat16")
? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16"))
: false;
Review comment (Contributor), on lines +71 to +74: I wonder whether we also need to handle the output_format attribute from quantize_op.cc?

Reply (Contributor Author): This attribute is only set but never used, so I skipped it.

std::string output_dtype;
if (is_bfloat) {
output_dtype = "bf16";
} else if (is_negative_output) {
output_dtype = "int8";
} else {
output_dtype = "uint8";
}
transpose2_op->Op()->SetAttr("output_data_type", output_dtype);
transpose2_op->Op()->SetInput("X",
std::vector<std::string>({quant_in->Name()}));

IR_NODE_LINK_TO(quant_in, transpose2_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out});
found_patterns_count++;
};
gpd(graph, handler);
AddStatis(found_patterns_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
paddle::string::PrettyLogDetail("--- fused %d quant with transpose2",
found_patterns_count);
}
}

void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
patterns::Transpose2Dequant transpose2_dequant_pattern(gpd.mutable_pattern(),
name_scope);
transpose2_dequant_pattern();

int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_in, dequant_in, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_op, dequant_op, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_out, dequant_out, transpose2_dequant_pattern);

if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with dequantize.";
return;
}

float scale =
dequant_op->Op()->HasAttr("Scale")
? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"))
: 1;
float reorder_scale = 1.0 / scale;
float shift =
dequant_op->Op()->HasAttr("Shift")
? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift"))
: 0;

transpose2_op->Op()->SetAttr("scale", reorder_scale);
transpose2_op->Op()->SetAttr("shift", shift);
transpose2_op->Op()->SetAttr("output_data_type", std::string("fp32"));
transpose2_op->Op()->SetOutput(
"Out", std::vector<std::string>({dequant_out->Name()}));

IR_NODE_LINK_TO(transpose2_op, dequant_out);
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
found_patterns_count++;
};

gpd(graph, handler);
AddStatis(found_patterns_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
paddle::string::PrettyLogDetail("--- fused %d transpose2 with dequant",
found_patterns_count);
}
}

void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const {
FuseQuantizeTranspose2(graph);
FuseTranspose2Dequantize(graph);
}
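// Illustrative summary (not part of the original diff) of what the two
// fusions do to the graph:
//   FuseQuantizeTranspose2:
//     in -> quantize{Scale, Shift} -> transpose2 -> out
//     becomes
//     in -> transpose2{scale = Scale, shift = Shift, output_data_type} -> out
//   FuseTranspose2Dequantize:
//     in -> transpose2 -> dequantize{Scale, Shift} -> out
//     becomes
//     in -> transpose2{scale = 1 / Scale, shift = Shift,
//                      output_data_type = "fp32"} -> out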

FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(quant_transpose2_dequant_onednn_fuse_pass,
paddle::framework::ir::FuseQuantTranspose2DequantOneDNNPass);
REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"transpose2", 0));
Review comment (Contributor): @Silv3S, should this be .EQ("transpose2", 0)? There is only version 0, and we don't know what will come in the future. Also, should there be a similar line for the quantize op, but with .LE("quantize", 2), since the pass works regardless of whether the shift attribute or bfloat16 is present? (See the bottom of quantize_op.cc.)

Reply (Member): I think it's more elegant to add .EQ rather than .GE, because we can't be sure that a newer version will be compatible with this pass.
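For illustration only, a hedged sketch of what the reviewer's suggestion might look like if a quantize version constraint were registered alongside the transpose2 one; the .LE("quantize", 2) bound is taken from the comment above and is an assumption, not part of the merged diff:

    REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
        .AddCombination(
            paddle::framework::compatible::OpVersionComparatorCombination()
                .EQ("transpose2", 0)
                .LE("quantize", 2));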

41 changes: 41 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h
@@ -0,0 +1,41 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase {
public:
virtual ~FuseQuantTranspose2DequantOneDNNPass() {}
FuseQuantTranspose2DequantOneDNNPass();

protected:
void ApplyImpl(Graph *graph) const override;
void FuseQuantizeTranspose2(Graph *graph) const;
void FuseTranspose2Dequantize(Graph *graph) const;

private:
std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass";
};

} // namespace ir
} // namespace framework

} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -474,6 +474,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
}
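As a usage note (not part of this diff), here is a minimal sketch of how this pass list is typically triggered from the C++ inference API; the include path and model directory are placeholders, and the calls assume the public paddle_infer::Config interface:

    #include "paddle_inference_api.h"  // placeholder include path

    int main() {
      // Hypothetical directory containing a quantized (INT8) model.
      paddle_infer::Config config("./quantized_model");
      config.EnableMKLDNN();
      // Appends the CPU INT8 pass list shown above, which now includes
      // quant_transpose2_dequant_onednn_fuse_pass.
      config.EnableMkldnnInt8();
      auto predictor = paddle_infer::CreatePredictor(config);
      return 0;
    }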
3 changes: 3 additions & 0 deletions paddle/fluid/operators/ops_extra_info.h
@@ -122,6 +122,9 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"Bias_scales", ExtraAttrProperty::ONEDNN},
{"Output_shift_scale", ExtraAttrProperty::ONEDNN},
{"Sum_scale", ExtraAttrProperty::ONEDNN},
{"scale", ExtraAttrProperty::ONEDNN},
{"shift", ExtraAttrProperty::ONEDNN},
{"output_data_type", ExtraAttrProperty::ONEDNN},
// GPUDNN dedicated attributes
{"exhaustive_search", ExtraAttrProperty::GPUDNN},
{"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN},