From 196acd4befe6466f1577d77d214905b6dfcd564a Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 25 Jul 2022 11:12:37 +0000 Subject: [PATCH 1/8] fix RemoveNode in fuse_elewise_add_act_pass --- .../framework/ir/fuse_elewise_add_act_pass.cc | 12 +++++++++++- paddle/fluid/framework/ir/graph.h | 2 ++ paddle/fluid/framework/ir/graph_helper.cc | 14 ++++++++++++++ .../fluid/framework/ir/graph_pattern_detector.cc | 9 +++++++-- paddle/fluid/framework/ir/graph_pattern_detector.h | 3 ++- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 5bd26e9eb9f2d..e7482fa3f9f1a 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,7 +297,17 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - GraphSafeRemoveNodes(graph, need_removed_nodes); + GraphSafeRemoveNodes(graph, need_removed_nodes, true); + if (!need_removed_nodes.empty()) { + if (!graph->Has(details::kRemovedVars)) { + graph->Set(details::kRemovedVars, + new std::unordered_set(need_removed_nodes)); + } else { + auto &removed_vars = + graph->Get>(details::kRemovedVars); + removed_vars.insert(need_removed_nodes.begin(), need_removed_nodes.end()); + } + } } void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 5a954110775d6..dce428ee38e88 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -45,6 +45,7 @@ namespace details { // This attr is not recommended, because the graph should not dependence // the program once it is built. constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs"; +constexpr char kRemovedVars[] = "removed_vars"; } // namespace details namespace ir { @@ -457,6 +458,7 @@ class Graph { std::map> attr_dels_; std::map> nodes_; std::unordered_set node_set_; + std::unorderd_set removed_node_set_; size_t num_node_created_{0}; // help to generate a unique node id. // NOTE(Aurelius84): Whether is constructed with partial ProgramDesc. // In case of @to_static, whole trainning program is splited into two diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 97f486065ac62..3d34340a7fba8 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -536,6 +536,20 @@ static void GraphToBlock(const Graph &graph, } } } + if (graph.Has(details::kRemovedVars)) { + auto &removed_vars = + graph.Get>(details::kRemovedVars); + for (const Node *n : removed_vars) { + if (n->IsVar()) { + if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && + !vars2remove.count(n->Var()->Name()) && + n->GetVarNodeBlockId() == graph.GetBlockId()) { + visited_vars.insert(n->Var()->Name()); + block->add_vars()->MergeFrom(*n->Var()->Proto()); + } + } + } + } block->clear_ops(); std::vector nodes; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6191c2efe9087..17664b51e5384 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -772,9 +772,14 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { } void GraphSafeRemoveNodes(Graph *graph, - const std::unordered_set &nodes) { + const std::unordered_set &nodes, + bool flag_save_nodes) { for (auto *node : nodes) { - graph->RemoveNode(const_cast(node)); + if (flag_save_nodes) { + graph->RemoveNode(const_cast(node)).release(); + } else { + graph->RemoveNode(const_cast(node)); + } } for (auto *node : graph->Nodes()) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 00e565b7161a2..d8198ee4e7d2b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -393,7 +393,8 @@ bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); // Graph safely remove some nodes, will automatically clean up the edges. void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes); + const std::unordered_set& nodes, + bool flag_save_nodes = false); // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage From 10d1cb5cb36c40ddd4e25ff1aebd12b81d294923 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 25 Jul 2022 11:18:05 +0000 Subject: [PATCH 2/8] fix --- paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc | 4 ---- paddle/fluid/framework/ir/graph.h | 1 - 2 files changed, 5 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index e7482fa3f9f1a..3ed68160443a9 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -302,10 +302,6 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { if (!graph->Has(details::kRemovedVars)) { graph->Set(details::kRemovedVars, new std::unordered_set(need_removed_nodes)); - } else { - auto &removed_vars = - graph->Get>(details::kRemovedVars); - removed_vars.insert(need_removed_nodes.begin(), need_removed_nodes.end()); } } } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index dce428ee38e88..1ccc82b4a9a59 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -458,7 +458,6 @@ class Graph { std::map> attr_dels_; std::map> nodes_; std::unordered_set node_set_; - std::unorderd_set removed_node_set_; size_t num_node_created_{0}; // help to generate a unique node id. // NOTE(Aurelius84): Whether is constructed with partial ProgramDesc. // In case of @to_static, whole trainning program is splited into two From cf1ab75f764123dfe3d82363d8844f00113093bf Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 25 Jul 2022 13:43:47 +0000 Subject: [PATCH 3/8] change pointer to share_ptr --- .../fluid/framework/ir/fuse_elewise_add_act_pass.cc | 7 ++++--- paddle/fluid/framework/ir/graph.h | 1 + paddle/fluid/framework/ir/graph_helper.cc | 5 ++--- paddle/fluid/framework/ir/graph_pattern_detector.cc | 11 ++++++++++- paddle/fluid/framework/ir/graph_pattern_detector.h | 5 ++++- paddle/fluid/pybind/ir.cc | 5 ++++- 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 3ed68160443a9..c087c55e61277 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,11 +297,12 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - GraphSafeRemoveNodes(graph, need_removed_nodes, true); - if (!need_removed_nodes.empty()) { + std::unordered_set> save_removed_nodes; + GraphSafeRemoveNodes(graph, need_removed_nodes, &save_removed_nodes, true); + if (!save_removed_nodes.empty()) { if (!graph->Has(details::kRemovedVars)) { graph->Set(details::kRemovedVars, - new std::unordered_set(need_removed_nodes)); + new details::RemovedVars(save_removed_nodes)); } } } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 1ccc82b4a9a59..3eb2df7011c7e 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -46,6 +46,7 @@ namespace details { // the program once it is built. constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs"; constexpr char kRemovedVars[] = "removed_vars"; +typedef std::unordered_set> RemovedVars; } // namespace details namespace ir { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 3d34340a7fba8..4ec478508f8eb 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -537,9 +537,8 @@ static void GraphToBlock(const Graph &graph, } } if (graph.Has(details::kRemovedVars)) { - auto &removed_vars = - graph.Get>(details::kRemovedVars); - for (const Node *n : removed_vars) { + auto &removed_vars = graph.Get(details::kRemovedVars); + for (auto &n : removed_vars) { if (n->IsVar()) { if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && !vars2remove.count(n->Var()->Name()) && diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 17664b51e5384..864dbdea0ad02 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -771,12 +771,21 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { return var->Name() == op->Op()->Output(argument)[nth]; } +void GraphSafeRemoveNodes(Graph *graph, + const std::unordered_set &nodes) { + std::unordered_set> empty_nodes; + GraphSafeRemoveNodes(graph, nodes, empty_nodes, false); +} + void GraphSafeRemoveNodes(Graph *graph, const std::unordered_set &nodes, + std::unordered_set> *save_nodes, bool flag_save_nodes) { for (auto *node : nodes) { if (flag_save_nodes) { - graph->RemoveNode(const_cast(node)).release(); + // prevent unique_ptr node from being released + save_nodes->insert( + std::move(graph->RemoveNode(const_cast(node)))); } else { graph->RemoveNode(const_cast(node)); } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index d8198ee4e7d2b..0be111377dc80 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -392,9 +392,12 @@ bool HasOutput(Node* op, const std::string& argument); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); // Graph safely remove some nodes, will automatically clean up the edges. +void GraphSafeRemoveNodes(Graph* graph, + const std::unordered_set& nodes); void GraphSafeRemoveNodes(Graph* graph, const std::unordered_set& nodes, - bool flag_save_nodes = false); + std::unordered_set>* save_nodes, + bool flag_save_nodes); // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b8b127201cccd..73f7e9a098c14 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -50,7 +50,10 @@ using pybind11::return_value_policy; namespace paddle { namespace pybind { void BindGraph(py::module *m) { - m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); + m->def("graph_safe_remove_nodes", + [](Graph *graph, const std::unordered_set &nodes) { + return GraphSafeRemoveNodes(graph, nodes); + }); m->def("has_circle", HasCircle); m->def("graph_num", GraphNum); m->def( From d130300cee0bdfd21215065f2a338777ae0486ef Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 25 Jul 2022 13:51:54 +0000 Subject: [PATCH 4/8] fix --- paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc | 2 +- paddle/fluid/framework/ir/graph_pattern_detector.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index c087c55e61277..3ecfd01da689e 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,7 +297,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - std::unordered_set> save_removed_nodes; + details::RemovedVars save_removed_nodes; GraphSafeRemoveNodes(graph, need_removed_nodes, &save_removed_nodes, true); if (!save_removed_nodes.empty()) { if (!graph->Has(details::kRemovedVars)) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 864dbdea0ad02..7a89e985cf0ff 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -774,7 +774,7 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { void GraphSafeRemoveNodes(Graph *graph, const std::unordered_set &nodes) { std::unordered_set> empty_nodes; - GraphSafeRemoveNodes(graph, nodes, empty_nodes, false); + GraphSafeRemoveNodes(graph, nodes, &empty_nodes, false); } void GraphSafeRemoveNodes(Graph *graph, From 24ea9079272cbde30b03db00678db83c0c065671 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 26 Jul 2022 14:37:58 +0000 Subject: [PATCH 5/8] fix --- .../framework/ir/fuse_elewise_add_act_pass.cc | 18 +++++++++------ paddle/fluid/framework/ir/graph_helper.cc | 23 ++++++++----------- .../framework/ir/graph_pattern_detector.cc | 18 +++++---------- .../framework/ir/graph_pattern_detector.h | 10 ++++---- paddle/fluid/pybind/ir.cc | 5 +--- 5 files changed, 31 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 3ecfd01da689e..67aa5a822edae 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,13 +297,17 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - details::RemovedVars save_removed_nodes; - GraphSafeRemoveNodes(graph, need_removed_nodes, &save_removed_nodes, true); - if (!save_removed_nodes.empty()) { - if (!graph->Has(details::kRemovedVars)) { - graph->Set(details::kRemovedVars, - new details::RemovedVars(save_removed_nodes)); - } + details::RemovedVars *saved_removed_nodes = new details::RemovedVars; + GraphSafeRemoveNodes(graph, need_removed_nodes, saved_removed_nodes); + if (!saved_removed_nodes->empty()) { + // TODO(pangyoki): If kRemovedVars exists, merge saved_removed_nodes into + // RemovedVars. + PADDLE_ENFORCE_EQ(graph->Has(details::kRemovedVars), + false, + platform::errors::PreconditionNotMet( + "Removed nodes are only saved for " + "fuse_elewise_add_act_pass in temporary.")); + graph->Set(details::kRemovedVars, saved_removed_nodes); } } diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 59ba04351b286..da3ff6bf2593c 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -568,6 +568,15 @@ static void GraphToBlock(const Graph &graph, vars_in_graph.emplace_back(*node->Var()->Proto()); } } + if (graph.Has(details::kRemovedVars)) { + auto &removed_vars = graph.Get(details::kRemovedVars); + for (auto &node : removed_vars) { + if (node->IsVar() && node->Var() && + node->GetVarNodeBlockId() == graph.GetBlockId()) { + vars_in_graph.emplace_back(*node->Var()->Proto()); + } + } + } // add vars_in_graph to blcok block->clear_vars(); @@ -581,20 +590,6 @@ static void GraphToBlock(const Graph &graph, } } - if (graph.Has(details::kRemovedVars)) { - auto &removed_vars = graph.Get(details::kRemovedVars); - for (auto &n : removed_vars) { - if (n->IsVar()) { - if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && - !vars2remove.count(n->Var()->Name()) && - n->GetVarNodeBlockId() == graph.GetBlockId()) { - visited_vars.insert(n->Var()->Name()); - block->add_vars()->MergeFrom(*n->Var()->Proto()); - } - } - } - } - block->clear_ops(); std::vector nodes; if (sort_kind != nullptr) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 7a89e985cf0ff..cce1ec89a2e82 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -771,20 +771,14 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { return var->Name() == op->Op()->Output(argument)[nth]; } -void GraphSafeRemoveNodes(Graph *graph, - const std::unordered_set &nodes) { - std::unordered_set> empty_nodes; - GraphSafeRemoveNodes(graph, nodes, &empty_nodes, false); -} - -void GraphSafeRemoveNodes(Graph *graph, - const std::unordered_set &nodes, - std::unordered_set> *save_nodes, - bool flag_save_nodes) { +void GraphSafeRemoveNodes( + Graph *graph, + const std::unordered_set &nodes, + std::unordered_set> *saved_nodes) { for (auto *node : nodes) { - if (flag_save_nodes) { + if (saved_nodes != nullptr) { // prevent unique_ptr node from being released - save_nodes->insert( + saved_nodes->insert( std::move(graph->RemoveNode(const_cast(node)))); } else { graph->RemoveNode(const_cast(node)); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0be111377dc80..794c25e85a555 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -392,12 +392,10 @@ bool HasOutput(Node* op, const std::string& argument); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); // Graph safely remove some nodes, will automatically clean up the edges. -void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes); -void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes, - std::unordered_set>* save_nodes, - bool flag_save_nodes); +void GraphSafeRemoveNodes( + Graph* graph, + const std::unordered_set& nodes, + std::unordered_set>* saved_nodes = nullptr); // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 73f7e9a098c14..b8b127201cccd 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -50,10 +50,7 @@ using pybind11::return_value_policy; namespace paddle { namespace pybind { void BindGraph(py::module *m) { - m->def("graph_safe_remove_nodes", - [](Graph *graph, const std::unordered_set &nodes) { - return GraphSafeRemoveNodes(graph, nodes); - }); + m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); m->def("has_circle", HasCircle); m->def("graph_num", GraphNum); m->def( From eb2066638f9226344aa65a54cbcaf0b7e9f14c9d Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 26 Jul 2022 15:26:46 +0000 Subject: [PATCH 6/8] fix format --- paddle/fluid/framework/ir/graph_helper.cc | 27 +++++++++++++---------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index da3ff6bf2593c..090778e52c3a8 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -549,6 +549,18 @@ static void GetGraphOpDesc(const std::vector &nodes, } } +template +static void GetGraphVarDesc(const Graph &graph, + const std::unordered_set &nodes, + std::vector *vars) { + for (T node : nodes) { + if (node->IsVar() && node->Var() && + node->GetVarNodeBlockId() == graph.GetBlockId()) { + vars.emplace_back(*node->Var()->Proto()); + } + } +} + static void GraphToBlock(const Graph &graph, proto::BlockDesc *block, const SortKind *sort_kind) { @@ -562,20 +574,11 @@ static void GraphToBlock(const Graph &graph, } std::vector vars_in_graph; - for (Node *node : graph.Nodes()) { - if (node->IsVar() && node->Var() && - node->GetVarNodeBlockId() == graph.GetBlockId()) { - vars_in_graph.emplace_back(*node->Var()->Proto()); - } - } + GetGraphVarDesc(graph, graph.Nodes(), &vars_in_graph); if (graph.Has(details::kRemovedVars)) { auto &removed_vars = graph.Get(details::kRemovedVars); - for (auto &node : removed_vars) { - if (node->IsVar() && node->Var() && - node->GetVarNodeBlockId() == graph.GetBlockId()) { - vars_in_graph.emplace_back(*node->Var()->Proto()); - } - } + GetGraphVarDesc>( + graph, removed_vars, &vars_in_graph); } // add vars_in_graph to blcok From 32e6e8d53309c7f9cf529312f2fb7e5e641edae7 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 26 Jul 2022 15:35:00 +0000 Subject: [PATCH 7/8] fix --- paddle/fluid/framework/ir/graph_helper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 090778e52c3a8..a7bf131805dc1 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -556,7 +556,7 @@ static void GetGraphVarDesc(const Graph &graph, for (T node : nodes) { if (node->IsVar() && node->Var() && node->GetVarNodeBlockId() == graph.GetBlockId()) { - vars.emplace_back(*node->Var()->Proto()); + vars->emplace_back(*node->Var()->Proto()); } } } From 2193bf0bfedc54e1db5b32849135249f2c1eef43 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 27 Jul 2022 04:01:40 +0000 Subject: [PATCH 8/8] fix graph_safe_remove_nodes --- paddle/fluid/pybind/ir.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b8b127201cccd..73f7e9a098c14 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -50,7 +50,10 @@ using pybind11::return_value_policy; namespace paddle { namespace pybind { void BindGraph(py::module *m) { - m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); + m->def("graph_safe_remove_nodes", + [](Graph *graph, const std::unordered_set &nodes) { + return GraphSafeRemoveNodes(graph, nodes); + }); m->def("has_circle", HasCircle); m->def("graph_num", GraphNum); m->def(