Skip to content

Commit

Permalink
fix one_to_many op Canonicalization (PaddlePaddle#15)
Browse files Browse the repository at this point in the history
* fix one_to_many op Canonicalization

* rename func
  • Loading branch information
gglin001 authored Aug 3, 2021
1 parent a03acc3 commit cc6895e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
24 changes: 24 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,29 @@ SymbolHandler GetHandler(const std::string &kind) {
return {};
}

void MoveNodeInputs(ir::Node *node, ir::Node *new_node) {
new_node->inputs = node->inputs;
for (auto *node_in : node->inputs) {
for (size_t i = 0; i < node_in->outputs.size(); ++i) {
if (node_in->outputs[i] == node) {
node_in->outputs[i] = new_node;
break;
}
}
}
}

void MoveNodeOutputs(ir::Node *node, ir::Node *new_node) {
new_node->outputs = node->outputs;
for (auto *node_out : node->outputs) {
for (size_t i = 0; i < node_out->inputs.size(); ++i) {
if (node_out->inputs[i] == node) {
node_out->inputs[i] = new_node;
break;
}
}
}
}

} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,8 @@ bool RegisterHandler(const std::string &, const SymbolHandler &);

SymbolHandler GetHandler(const std::string &);

void MoveNodeInputs(ir::Node *node, ir::Node *new_node);
void MoveNodeOutputs(ir::Node *node, ir::Node *new_node);

} // namespace framework
} // namespace paddle
23 changes: 4 additions & 19 deletions paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,12 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
SymbolHandler handler = GetHandler(op_type);
if (handler) {
new_node = handler(graph, node);
new_node->inputs = node->inputs;
new_node->outputs = node->outputs;
// restore node releations
for (auto* node_in : node->inputs) {
for (size_t i = 0; i < node_in->outputs.size(); ++i) {
if (node_in->outputs[i] == node) {
node_in->outputs[i] = new_node;
break;
}
}
if (new_node->inputs.empty()) {
MoveNodeInputs(node, new_node);
}
for (auto* node_out : node->outputs) {
for (size_t i = 0; i < node_out->inputs.size(); ++i) {
if (node_out->inputs[i] == node) {
node_out->inputs[i] = new_node;
break;
}
}
if (new_node->outputs.empty()) {
MoveNodeOutputs(node, new_node);
}
}
if (new_node) {
graph->RemoveNode(node);
}
}
Expand Down

0 comments on commit cc6895e

Please sign in to comment.