Correct multiple inputs and outputs
wozna committed Dec 9, 2022
1 parent a1bdc65 commit af961a2
Showing 3 changed files with 129 additions and 35 deletions.
102 changes: 77 additions & 25 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -146,6 +146,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
                                      float shift,
                                      std::string shift_attr_name) const {
   auto inputs = op->inputs;
+  auto var_names = op->Op()->Inputs().at(input_name);
+  std::vector<std::string> unique_var_names;
+  for (unsigned i = 0; i < var_names.size(); i++)
+    if (std::find(unique_var_names.begin(),
+                  unique_var_names.end(),
+                  var_names[i]) == unique_var_names.end())
+      unique_var_names.push_back(var_names[i]);
+
   auto output = op->outputs[0];
   PADDLE_ENFORCE_GE(inputs.size(),
                     1,
@@ -163,33 +171,59 @@ void CPUQuantizePass::QuantizeInputs(Graph* g,
   // create a quantize op desc prototype
   OpDesc q_desc;
   q_desc.SetType("quantize");

   std::vector<Node*> quantize_out_nodes(inputs.size());
   std::vector<std::string> quantize_out_node_names(inputs.size());

   double scale_out = GetScaleValueForNode(output);
   unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
   float scale = scale_out * max;

-  for (size_t i = 0; i < inputs.size(); i++) {
-    // Create quantize output variable
+  for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
+    auto index = -1;
+    for (size_t it = 0; it < inputs.size(); it++) {
+      if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
+    }
+
+    if (index == -1) {
+      PADDLE_ENFORCE_NE(index,
+                        -1,
+                        platform::errors::InvalidArgument(
+                            "Var(%s) isn't the input of the %s operator.",
+                            unique_var_names[var_id],
+                            op->Op()->Type()));
+    }
+
+    auto* input = inputs.at(index);
+
     VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
-    quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
-    quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
+    quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
+    quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();

     q_desc.SetAttr("Scale", scale);
     q_desc.SetAttr("Shift", shift);
-    q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
-    q_desc.SetOutput("Output",
-                     std::vector<std::string>({quantize_out_node_names[i]}));
+    q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
+    q_desc.SetOutput(
+        "Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
     q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
     auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

     // link quantize op
-    UnlinkNodes(inputs[i], op);
-    IR_NODE_LINK_TO(inputs[i], quantize_op);
-    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
-    IR_NODE_LINK_TO(quantize_out_nodes[i], op);
+    UnlinkNodes(input, op);
+    IR_NODE_LINK_TO(input, quantize_op);
+    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
+    IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
   }

+  // If any inputs were duplicated, map their remaining occurrences to the
+  // already-created quantize outputs so the op's inputs keep their original
+  // order.
+  for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
+    auto index = std::find(
+        unique_var_names.begin(), unique_var_names.end(), var_names[i]);
+    if (index != unique_var_names.end()) {
+      auto id = std::distance(unique_var_names.begin(), index);
+      quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
+      IR_NODE_LINK_TO(quantize_out_nodes[id], op);
+    }
+  }
+
   // update op's input
@@ -252,44 +286,62 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g,
                                         bool is_unsigned,
                                         std::string scale_attr_name) const {
   auto outputs = op->outputs;
+  auto var_names = op->Op()->Outputs().at(output_name);
+
   PADDLE_ENFORCE_GE(outputs.size(),
                     1,
                     platform::errors::InvalidArgument(
                         "OP(%s)'s outputs(%d) must be greater than or equal to 1.",
                         op->Name(),
                         outputs.size()));

-  std::vector<std::string> quantize_in_node_names(outputs.size());
+  std::vector<std::string> dequantize_in_node_names(outputs.size());
+  std::vector<Node*> dequantize_in_nodes(outputs.size());

   unsigned max = is_unsigned ? U8_MAX : S8_MAX;
   float scale = scale_to_one * max;

-  for (size_t i = 0; i < outputs.size(); i++) {
+  for (size_t var_id = 0; var_id < var_names.size(); var_id++) {
+    auto index = -1;
+    for (size_t it = 0; it < outputs.size(); it++) {
+      if (outputs[it]->Name() == var_names[var_id]) index = it;
+    }
+
+    if (index == -1) {
+      PADDLE_ENFORCE_NE(index,
+                        -1,
+                        platform::errors::InvalidArgument(
+                            "Var(%s) isn't the output of the %s operator.",
+                            var_names[var_id],
+                            op->Op()->Type()));
+    }
+
+    auto* output = outputs.at(index);
+
     // Create dequantize input variable
     VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
-    Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
-    quantize_in_node_names[i] = dequantize_in_node->Name();
+    dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc);
+    dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name();

     // create a dequantize op node for output.
     OpDesc deq_desc;
     deq_desc.SetType("dequantize");
-    deq_desc.SetInput("Input",
-                      std::vector<std::string>({quantize_in_node_names[i]}));
-    deq_desc.SetOutput("Output",
-                       std::vector<std::string>({outputs[i]->Name()}));
+    deq_desc.SetInput(
+        "Input", std::vector<std::string>({dequantize_in_node_names[var_id]}));
+    deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
     deq_desc.SetAttr("Scale", scale);
     deq_desc.SetAttr("is_negative_input", !is_unsigned);
     auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

     // link dequantize op
-    UnlinkNodes(op, outputs[i]);
-    IR_NODE_LINK_TO(op, dequantize_in_node);
-    IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
-    IR_NODE_LINK_TO(dequantize_op, outputs[i]);
+    UnlinkNodes(op, output);
+    IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]);
+    IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op);
+    IR_NODE_LINK_TO(dequantize_op, output);
   }

   // update op's output
-  op->Op()->SetOutput(output_name, quantize_in_node_names);
+  op->Op()->SetOutput(output_name, dequantize_in_node_names);
   if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
 }

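Note: the dedup-and-remap pattern introduced in QuantizeInputs above can be
exercised in isolation. The following is a minimal standalone sketch, not
Paddle code: graph nodes are replaced by plain strings, and the names
var_names and quantized are hypothetical.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main() {
  // An op such as concat may list the same variable twice, e.g. (c1, c2, c1, c2).
  std::vector<std::string> var_names = {"c1", "c2", "c1", "c2"};

  // Keep the first occurrence of each name, preserving order (mirrors
  // unique_var_names in the pass).
  std::vector<std::string> unique_var_names;
  for (const auto& name : var_names) {
    if (std::find(unique_var_names.begin(), unique_var_names.end(), name) ==
        unique_var_names.end())
      unique_var_names.push_back(name);
  }

  // Create one quantized output per distinct input (stand-in for the
  // quantize var nodes created by the pass).
  std::vector<std::string> quantized(var_names.size());
  for (std::size_t i = 0; i < unique_var_names.size(); ++i)
    quantized[i] = unique_var_names[i] + "_quantized";

  // Fill the duplicated trailing slots with the already-created names so
  // the op sees its inputs in the original order again.
  for (std::size_t i = unique_var_names.size(); i < var_names.size(); ++i) {
    auto it = std::find(
        unique_var_names.begin(), unique_var_names.end(), var_names[i]);
    if (it != unique_var_names.end())
      quantized[i] = quantized[std::distance(unique_var_names.begin(), it)];
  }

  // Prints: c1_quantized c2_quantized c1_quantized c2_quantized
  for (const auto& name : quantized) std::cout << name << " ";
  std::cout << "\n";
  return 0;
}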
39 changes: 39 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -881,6 +881,45 @@ TEST(CpuQuantizePass, multi_gru_3) {
   MainTestMultiGru(layers);
 }

+static const std::initializer_list<std::string>
+    variable_names_multi_inputs_outputs = {"a", "b", "c1", "c2", "d", "e"};
+
+// a->Pool->b
+// b->Split->c1, c2
+// (c1, c2, c1, c2)->Concat->d
+// d->Pool->e
+ProgramDesc BuildProgramDescMulti() {
+  ProgramDesc prog;
+  for (auto& v : variable_names_multi_inputs_outputs) {
+    prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
+  }
+
+  SetOp(&prog, "pool2d", "Pool", {"a"}, {"b"}, true, "float32");
+  SetOp(&prog, "split", "Split", {"b"}, {"c1", "c2"}, true, "int8");
+  SetOp(
+      &prog, "concat", "Concat", {"c1", "c2", "c1", "c2"}, {"d"}, true, "int8");
+  SetOp(&prog, "pool2d", "Pool2", {"d"}, {"e"}, true, "float32");
+
+  return prog;
+}
+
+TEST(CpuQuantizePass, multi_inputs_outputs_ops) {
+  // a->QUANT1->Split
+  // b1->DEQUANT->OUT->QUANT
+  // b2->DEQUANT->OUT->QUANT
+  // (b1, b2, b1, b2)->Concat->c->DEQUANT->Pool->d
+  int added_nodes = 6 * 2;
+  std::unordered_map<std::string, int> expected_operators = {{"pool2d", 2},
+                                                             {"concat", 1},
+                                                             {"split", 1},
+                                                             {"quantize", 3},
+                                                             {"dequantize", 3}};
+  MainTest(BuildProgramDescMulti(),
+           variable_names_multi_inputs_outputs,
+           expected_operators,
+           added_nodes);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
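As a sanity check on added_nodes = 6 * 2 in the test above: each inserted
quantize or dequantize brings exactly two new graph nodes, the op node itself
and its new intermediate variable node. Assuming the counts from
expected_operators (3 quantize, 3 dequantize), a minimal sketch of the
arithmetic:

#include <cassert>

int main() {
  // 1 quantize before Split's input b, 2 before Concat's unique inputs c1, c2.
  const int quantize_ops = 1 + 2;
  // 2 dequantize after Split's outputs c1, c2, 1 after Concat's output d.
  const int dequantize_ops = 2 + 1;
  // Each insertion adds one op node plus one intermediate variable node.
  const int nodes_per_insertion = 2;
  assert((quantize_ops + dequantize_ops) * nodes_per_insertion == 6 * 2);
  return 0;
}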
23 changes: 13 additions & 10 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -158,6 +158,11 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
         PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
     float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift");
     float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift");
+    if (quant_op->Op()->GetAttrIfExists<bool>("is_negative_input") !=
+        dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
+      return;
+    }
+
     PADDLE_ENFORCE_NE(
         nodes_keep_counter->find(dequant_out),
         nodes_keep_counter->end(),
@@ -169,14 +174,13 @@
     if (dequant_scale == quant_scale && dequant_shift == quant_shift) {
       // squash dequantize-quantize to nothing
       auto quant_out_var_name = quant_out->Name();
-      auto next_op_inputs = next_op_desc->InputNames();
-      for (const auto& name : next_op_inputs) {
-        auto input_names = next_op_desc->Input(name);
+      for (auto input_name : next_op_desc->InputNames()) {
+        auto& input_names = next_op_desc->MutableInputs()->at(input_name);
         std::replace(input_names.begin(),
                      input_names.end(),
                      quant_out_var_name,
                      dequant_in->Name());
-        next_op_desc->SetInput(name, input_names);
+        next_op_desc->SetInput(input_name, input_names);
       }

       if (keep_dequant)
@@ -413,12 +417,11 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {

       // update the next operator input,
       // by replacing quant_out with first_quant_out
-      auto last_op_names = last_op->Op()->Input(last_op_input_name);
-      last_op_names.erase(
-          std::remove(
-              last_op_names.begin(), last_op_names.end(), quant_out->Name()),
-          last_op_names.end());
-      last_op_names.push_back(first_quant_out->Name());
+      auto last_op_names = last_op->Op()->Inputs().at(last_op_input_name);
+      std::replace(last_op_names.begin(),
+                   last_op_names.end(),
+                   quant_out->Name(),
+                   first_quant_out->Name());
       last_op_op->SetInput(last_op_input_name,
                            std::vector<std::string>(last_op_names));

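The last hunk swaps an erase-and-append for std::replace. A standalone sketch
of why, with hypothetical variable names: erasing every occurrence of
quant_out and pushing first_quant_out onto the back both collapses duplicates
and moves the argument out of position, while std::replace substitutes each
occurrence in place.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  // An op input list where the quantize output appears twice, mid-list.
  std::vector<std::string> names = {"x", "quant_out", "y", "quant_out"};

  // Old approach: remove all occurrences, then append the new name once.
  auto old_names = names;
  old_names.erase(std::remove(old_names.begin(),
                              old_names.end(),
                              std::string("quant_out")),
                  old_names.end());
  old_names.push_back("first_quant_out");
  // Order and multiplicity are lost: {"x", "y", "first_quant_out"}.
  assert((old_names == std::vector<std::string>{"x", "y", "first_quant_out"}));

  // New approach: substitute in place, preserving position and count.
  std::replace(names.begin(),
               names.end(),
               std::string("quant_out"),
               std::string("first_quant_out"));
  assert((names == std::vector<std::string>{
             "x", "first_quant_out", "y", "first_quant_out"}));
  return 0;
}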
