From 9745be198ad4cdb437123a3376369cebbd07698d Mon Sep 17 00:00:00 2001 From: zhangbaizhou Date: Sun, 28 Apr 2024 03:27:50 +0000 Subject: [PATCH] move logic of reverse topo sort to pattern_graph.cc --- .../graph_transformer/search_algorithm.h | 19 ++----------- paddle/cinn/operator_fusion/pattern_graph.cc | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h b/paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h index db15ce8aa9ba00..354b7fca137bbf 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h +++ b/paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h @@ -111,22 +111,9 @@ struct SearchAlgorithm, int> - unvisited_nodes_to_out_degree; - for (const auto& node_ptr : graph->all_pattern_nodes()) { - unvisited_nodes_to_out_degree[node_ptr] = node_ptr->downstream().size(); - } - - while (!unvisited_nodes_to_out_degree.empty()) { - const auto& it = - std::find_if(unvisited_nodes_to_out_degree.begin(), - unvisited_nodes_to_out_degree.end(), - [&](const auto& pair) { return pair.second == 0; }); - reverse_topo_nodes.push(it->first); - for (const auto& upstream : it->first->upstream()) { - --unvisited_nodes_to_out_degree[upstream]; - } - unvisited_nodes_to_out_degree.erase(it); + auto reverse_topo_sort_result = graph->SortByReverseTopoOrder(); + for (const auto& node : reverse_topo_sort_result) { + reverse_topo_nodes.push(node); } } diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index bd4bd962b0877c..9bff297462af6b 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -98,6 +98,34 @@ std::vector> PatternGraph::SortByTopoOrder() { return res; } +template +std::vector> PatternGraph::SortByReverseTopoOrder() { + // sort all_pattern_nodes_ by reverse topo order. + std::vector> res; + std::list> reverse_topo_queue; + std::map, int> degree; + + for (const auto& node : all_pattern_nodes_) { + degree[node] = node->downstream().size(); + if (degree[node] == 0) { + reverse_topo_queue.push_back(node); + } + } + + while (!reverse_topo_queue.empty()) { + PatternNodePtr node = reverse_topo_queue.front(); + reverse_topo_queue.pop_front(); + res.push_back(node); + for (const auto& upstream : node->upstream()) { + degree[upstream]--; + if (degree[upstream] == 0) { + reverse_topo_queue.push_back(upstream); + } + } + } + return res; +} + template void PatternGraph::SinkTrivialPattern() { // TODO(@wuzhanfei) change sink trivial pattern algorithm, skip pattern with