Skip to content

Commit

Permalink
Merge branch 'multi_downstream' of https://github.com/feifei-111/Paddle
Browse files Browse the repository at this point in the history
… into multi_downstream
  • Loading branch information
feifei-111 committed Apr 28, 2024
2 parents 3b74908 + 4eda8bc commit 03d4fa2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
19 changes: 3 additions & 16 deletions paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,9 @@ struct SearchAlgorithm<ReverseTopoNodePairPattern,
graph_ = graph;

// Do reverse topological sort, and store the results in reverse_topo_nodes.
std::unordered_map<PatternNodePtr<Phrase>, int>
unvisited_nodes_to_out_degree;
for (const auto& node_ptr : graph->all_pattern_nodes()) {
unvisited_nodes_to_out_degree[node_ptr] = node_ptr->downstream().size();
}

while (!unvisited_nodes_to_out_degree.empty()) {
const auto& it =
std::find_if(unvisited_nodes_to_out_degree.begin(),
unvisited_nodes_to_out_degree.end(),
[&](const auto& pair) { return pair.second == 0; });
reverse_topo_nodes.push(it->first);
for (const auto& upstream : it->first->upstream()) {
--unvisited_nodes_to_out_degree[upstream];
}
unvisited_nodes_to_out_degree.erase(it);
auto reverse_topo_sort_result = graph->SortByReverseTopoOrder();
for (const auto& node : reverse_topo_sort_result) {
reverse_topo_nodes.push(node);
}
}

Expand Down
28 changes: 28 additions & 0 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,34 @@ std::vector<PatternNodePtr<T>> PatternGraph<T>::SortByTopoOrder() {
return res;
}

template <typename T>
std::vector<PatternNodePtr<T>> PatternGraph<T>::SortByReverseTopoOrder() {
// sort all_pattern_nodes_ by reverse topo order.
std::vector<PatternNodePtr<T>> res;
std::list<PatternNodePtr<T>> reverse_topo_queue;
std::map<PatternNodePtr<T>, int> degree;

for (const auto& node : all_pattern_nodes_) {
degree[node] = node->downstream().size();
if (degree[node] == 0) {
reverse_topo_queue.push_back(node);
}
}

while (!reverse_topo_queue.empty()) {
PatternNodePtr<T> node = reverse_topo_queue.front();
reverse_topo_queue.pop_front();
res.push_back(node);
for (const auto& upstream : node->upstream()) {
degree[upstream]--;
if (degree[upstream] == 0) {
reverse_topo_queue.push_back(upstream);
}
}
}
return res;
}

template <typename T>
void PatternGraph<T>::SinkTrivialPattern() {
// TODO(@wuzhanfei) change sink trivial pattern algorithm, skip pattern with
Expand Down

0 comments on commit 03d4fa2

Please sign in to comment.