Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#13 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
skip multi-downstream nodes when doing trivial sink
  • Loading branch information
feifei-111 authored Apr 28, 2024
2 parents 03d4fa2 + 18a3eb8 commit 2a85bbf
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 39 deletions.
19 changes: 12 additions & 7 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ struct CanFuseRxTMatcher {
}
};

struct SinkTrivialMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
return StmtPatternGraphMatcher<TrivialPattern<T>>()(graph, node) &&
node->downstream().size() == 1 &&
(std::holds_alternative<ReducePattern<Phrase>>(
node->downstream().at(0)->stmt_pattern()) ||
std::holds_alternative<TrivialPattern<Phrase>>(
node->downstream().at(0)->stmt_pattern()));
}
};

struct CanFuseReduceTreeMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
Expand Down Expand Up @@ -133,13 +145,6 @@ struct HorizontalFusionMatcher {
}
};

struct NonSinkNodeMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
return !node->downstream().empty();
}
};

struct IsOutputNodeMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
Expand Down
59 changes: 36 additions & 23 deletions paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,35 +79,48 @@ struct LiftReduceToReduceTreeOperation {
};

struct MergeTrivialPatternOperation {
// TOOD(@wuzhanfei)
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph,
PatternNodePtr<Phrase> upstream) {
std::vector<PatternNodePtr<Phrase>> fusion_candidate =
upstream->downstream();
upstream->ClearDownstream();
for (const auto& downstream : fusion_candidate) {
if (std::holds_alternative<ReducePattern<Phrase>>(
downstream->stmt_pattern()) ||
std::holds_alternative<TrivialPattern<Phrase>>(
downstream->stmt_pattern())) {
auto merged_node =
graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
graph->RemoveNode(downstream);
VLOG(4) << "MergeTrivialPatternOperation: \nupstream "
<< upstream->DebugStr() << "\ndownstream "
<< downstream->DebugStr() << "\nmerged "
<< merged_node->DebugStr();
} else {
upstream->AddNodeToDownstream(downstream);
}
}
if (upstream->downstream().empty()) {
graph->RemoveNode(upstream);
}
const auto& downstream = node->downstream().at(0);
auto merged_node =
graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
graph->RemoveNode(downstream);
VLOG(4) << "MergeTrivialPatternOperation: \nupstream "
<< upstream->DebugStr() << "\ndownstream " << downstream->DebugStr()
<< "\nmerged " << merged_node->DebugStr();
}
};

// struct MergeTrivialPatternOperation {
// template <typename Phrase>
// void operator()(PatternGraph<Phrase>* graph,
// PatternNodePtr<Phrase> upstream) {
// std::vector<PatternNodePtr<Phrase>> fusion_candidate =
// upstream->downstream();
// upstream->ClearDownstream();
// for (const auto& downstream : fusion_candidate) {
// if (std::holds_alternative<ReducePattern<Phrase>>(
// downstream->stmt_pattern()) ||
// std::holds_alternative<TrivialPattern<Phrase>>(
// downstream->stmt_pattern())) {
// auto merged_node =
// graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
// graph->RemoveNode(downstream);
// VLOG(4) << "MergeTrivialPatternOperation: \nupstream "
// << upstream->DebugStr() << "\ndownstream "
// << downstream->DebugStr() << "\nmerged "
// << merged_node->DebugStr();
// } else {
// upstream->AddNodeToDownstream(downstream);
// }
// }
// if (upstream->downstream().empty()) {
// graph->RemoveNode(upstream);
// }
// }
// };

struct LiftToHorizontalFusionPatternOperation {
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph, PatternNodePtr<Phrase> node) {
Expand Down
13 changes: 4 additions & 9 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ template <typename T>
std::vector<PatternNodePtr<T>> PatternGraph<T>::ClusterOps() {
VLOG(4) << "[Group Cluster] Initial Condition: " << GraphInfo();

// TODO(@wuzhanfei) Remove All IsOutputPattern Check

VLOG(4) << "[Group Cluster] Start SinkTrivialPattern";
SinkTrivialPattern();
VLOG(4) << "[Group Cluster] After SinkTrivialPattern: " << GraphInfo();
Expand Down Expand Up @@ -128,13 +126,10 @@ std::vector<PatternNodePtr<T>> PatternGraph<T>::SortByReverseTopoOrder() {

template <typename T>
void PatternGraph<T>::SinkTrivialPattern() {
// TODO(@wuzhanfei) change sink trivial pattern algorithm, skip pattern with
// multi downstream
GraphTransformer<
NodePattern,
T,
And<NonSinkNodeMatcher, StmtPatternGraphMatcher<TrivialPattern<T>>>,
MergeTrivialPatternOperation>(this);
GraphTransformer<NodePattern,
T,
SinkTrivialMatcher,
MergeTrivialPatternOperation>(this);
}

template <typename T>
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/operator_fusion/pattern_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class PatternGraph {
const PatternNodePtr<T>& downstream,
MergePatternFn<T> merge_pattern_fn);
std::vector<PatternNodePtr<T>> SortByTopoOrder();
std::vector<PatternNodePtr<T>> SortByReverseTopoOrder();

const PatternNodePtrSet<T>& all_pattern_nodes() const {
return all_pattern_nodes_;
Expand Down

0 comments on commit 2a85bbf

Please sign in to comment.