Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix RemoveIntermediateOut in fuse_elewise_add_act_pass while converting graph to program #44593

Merged
merged 10 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,18 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
}
}
}
GraphSafeRemoveNodes(graph, need_removed_nodes);
details::RemovedVars *saved_removed_nodes = new details::RemovedVars;
GraphSafeRemoveNodes(graph, need_removed_nodes, saved_removed_nodes);
if (!saved_removed_nodes->empty()) {
// TODO(pangyoki): If kRemovedVars exists, merge saved_removed_nodes into
// RemovedVars.
PADDLE_ENFORCE_EQ(graph->Has(details::kRemovedVars),
false,
platform::errors::PreconditionNotMet(
"Removed nodes are only saved for "
"fuse_elewise_add_act_pass in temporary."));
graph->Set(details::kRemovedVars, saved_removed_nodes);
}
}

void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ namespace details {
// This attr is not recommended, because the graph should not dependence
// the program once it is built.
constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
constexpr char kRemovedVars[] = "removed_vars";
typedef std::unordered_set<std::shared_ptr<ir::Node>> RemovedVars;
} // namespace details

namespace ir {
Expand Down
22 changes: 17 additions & 5 deletions paddle/fluid/framework/ir/graph_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,18 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
}
}

template <class T = Node *>
static void GetGraphVarDesc(const Graph &graph,
const std::unordered_set<T> &nodes,
std::vector<proto::VarDesc> *vars) {
for (T node : nodes) {
if (node->IsVar() && node->Var() &&
node->GetVarNodeBlockId() == graph.GetBlockId()) {
vars->emplace_back(*node->Var()->Proto());
}
}
}

static void GraphToBlock(const Graph &graph,
proto::BlockDesc *block,
const SortKind *sort_kind) {
Expand All @@ -562,11 +574,11 @@ static void GraphToBlock(const Graph &graph,
}

std::vector<proto::VarDesc> vars_in_graph;
for (Node *node : graph.Nodes()) {
if (node->IsVar() && node->Var() &&
node->GetVarNodeBlockId() == graph.GetBlockId()) {
vars_in_graph.emplace_back(*node->Var()->Proto());
}
GetGraphVarDesc<Node *>(graph, graph.Nodes(), &vars_in_graph);
if (graph.Has(details::kRemovedVars)) {
auto &removed_vars = graph.Get<details::RemovedVars>(details::kRemovedVars);
GetGraphVarDesc<std::shared_ptr<ir::Node>>(
graph, removed_vars, &vars_in_graph);
}

// add vars_in_graph to blcok
Expand Down
14 changes: 11 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,18 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
return var->Name() == op->Op()->Output(argument)[nth];
}

void GraphSafeRemoveNodes(Graph *graph,
const std::unordered_set<const Node *> &nodes) {
void GraphSafeRemoveNodes(
Graph *graph,
const std::unordered_set<const Node *> &nodes,
std::unordered_set<std::shared_ptr<Node>> *saved_nodes) {
for (auto *node : nodes) {
graph->RemoveNode(const_cast<Node *>(node));
if (saved_nodes != nullptr) {
// prevent unique_ptr node from being released
saved_nodes->insert(
std::move(graph->RemoveNode(const_cast<Node *>(node))));
} else {
graph->RemoveNode(const_cast<Node *>(node));
}
}

for (auto *node : graph->Nodes()) {
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,10 @@ bool HasOutput(Node* op, const std::string& argument);
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);

// Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes);
void GraphSafeRemoveNodes(
Graph* graph,
const std::unordered_set<const Node*>& nodes,
std::unordered_set<std::shared_ptr<Node>>* saved_nodes = nullptr);

// Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better re-usage
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
m->def("graph_safe_remove_nodes",
[](Graph *graph, const std::unordered_set<const Node *> &nodes) {
return GraphSafeRemoveNodes(graph, nodes);
});
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def(
Expand Down