Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#16 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
Implementation of anchor pattern recomputing mechanism
  • Loading branch information
feifei-111 authored Apr 30, 2024
2 parents eb6ef40 + affebad commit af8078c
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 42 deletions.
7 changes: 7 additions & 0 deletions paddle/cinn/operator_fusion/backend/pattern_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ StmtPattern<BackendStage> MergePatternImpl(
return TrivialPattern<BackendStage>(ops, trivial_op);
}

template <>
StmtPattern<BackendStage> MergePatternImpl(
const AnchorPattern<BackendStage>& first,
const AnchorPattern<BackendStage>& second) {
// TODO(@wuzhanfei)
}

template <>
StmtPattern<BackendStage> MergePatternImpl(
const HorizontalFusionPattern<BackendStage>& first,
Expand Down
5 changes: 5 additions & 0 deletions paddle/cinn/operator_fusion/backend/pattern_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ StmtPattern<BackendStage> MergePatternImpl(
const TrivialPattern<BackendStage>& first,
const TrivialPattern<BackendStage>& second);

template <>
StmtPattern<BackendStage> MergePatternImpl(
const AnchorPattern<BackendStage>& first,
const AnchorPattern<BackendStage>& second);

template <>
StmtPattern<BackendStage> MergePatternImpl(
const HorizontalFusionPattern<BackendStage>& first,
Expand Down
7 changes: 7 additions & 0 deletions paddle/cinn/operator_fusion/frontend/pattern_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ StmtPattern<FrontendStage> MergePatternImpl(
return TrivialPattern<FrontendStage>(contents);
}

template <>
StmtPattern<FrontendStage> MergePatternImpl(
const AnchorPattern<FrontendStage>& first,
const AnchorPattern<FrontendStage>& second) {
// TODO(@wuzhanfei)
}

template <>
StmtPattern<FrontendStage> MergePatternImpl(
const HorizontalFusionPattern<FrontendStage>& first,
Expand Down
5 changes: 5 additions & 0 deletions paddle/cinn/operator_fusion/frontend/pattern_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ StmtPattern<FrontendStage> MergePatternImpl(
const TrivialPattern<FrontendStage>& first,
const TrivialPattern<FrontendStage>& second);

template <>
StmtPattern<FrontendStage> MergePatternImpl(
const AnchorPattern<FrontendStage>& first,
const AnchorPattern<FrontendStage>& second);

template <>
StmtPattern<FrontendStage> MergePatternImpl(
const HorizontalFusionPattern<FrontendStage>& first,
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ struct CanFuseReduceTreeAndTrivialMatcher {
struct RecomputeNodeMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
// TODO(@wuzhanfei)
return false;
return StmtPatternGraphMatcher<AnchorPattern<T>>()(graph, node) &&
node->downstream().size() > 1 &&
(node->stmt_pattern.can_recompute());
}
};

Expand Down
85 changes: 47 additions & 38 deletions paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,47 +80,23 @@ struct LiftReduceToReduceTreeOperation {

struct MergeTrivialPatternOperation {
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph,
PatternNodePtr<Phrase> upstream) {
void operator()(PatternGraph<Phrase>* graph, PatternNodePtr<Phrase> node) {
PADDLE_ENFORCE_EQ(
node->downstream().size(),
1,
phi::errors::PreconditionNotMet("The downstream of the Sink Trivial "
"Pattern node should be 1, but got %d.",
node->downstream().size()));
const auto& downstream = node->downstream().at(0);
auto merged_node =
graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
auto merged_node = graph->MergeNode(node, downstream, MergePattern<Phrase>);
graph->RemoveNode(downstream);
VLOG(4) << "MergeTrivialPatternOperation: \nupstream "
<< upstream->DebugStr() << "\ndownstream " << downstream->DebugStr()
<< "\nmerged " << merged_node->DebugStr();
graph->RemoveNode(node);
VLOG(4) << "MergeTrivialPatternOperation: \nupstream " << node->DebugStr()
<< "\ndownstream " << downstream->DebugStr() << "\nmerged "
<< merged_node->DebugStr();
}
};

// struct MergeTrivialPatternOperation {
// template <typename Phrase>
// void operator()(PatternGraph<Phrase>* graph,
// PatternNodePtr<Phrase> upstream) {
// std::vector<PatternNodePtr<Phrase>> fusion_candidate =
// upstream->downstream();
// upstream->ClearDownstream();
// for (const auto& downstream : fusion_candidate) {
// if (std::holds_alternative<ReducePattern<Phrase>>(
// downstream->stmt_pattern()) ||
// std::holds_alternative<TrivialPattern<Phrase>>(
// downstream->stmt_pattern())) {
// auto merged_node =
// graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
// graph->RemoveNode(downstream);
// VLOG(4) << "MergeTrivialPatternOperation: \nupstream "
// << upstream->DebugStr() << "\ndownstream "
// << downstream->DebugStr() << "\nmerged "
// << merged_node->DebugStr();
// } else {
// upstream->AddNodeToDownstream(downstream);
// }
// }
// if (upstream->downstream().empty()) {
// graph->RemoveNode(upstream);
// }
// }
// };

struct LiftToHorizontalFusionPatternOperation {
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph, PatternNodePtr<Phrase> node) {
Expand Down Expand Up @@ -158,8 +134,41 @@ struct FuseDownstreamAnchorOperation {

struct SplitRecomputeOperation {
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph, PatternNodePtr<Phrase> node) {
// TODO(@wuzhanfei)
void operator()(PatternGraph<Phrase>* graph,
PatternNodePtr<Phrase> upstream) {
PADDLE_ENFORCE_GT(upstream->downstream().size(),
1,
phi::errors::PreconditionNotMet(
"The downstream of node for recomputation should be "
"more than 1, but got %d.",
upstream->downstream().size()));

std::vector<PatternNodePtr<Phrase>> fusion_candidate =
upstream->downstream();
upstream->ClearDownstream();

for (const auto& downstream : fusion_candidate) {
bool can_fuse = graph->policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasDownstreamAnchor(upstream, downstream) ||
graph->policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasUpstreamAnchor(upstream, downstream);
if (can_fuse) {
auto merged_node =
graph->MergeNode(upstream, downstream, MergePattern<Phrase>);
graph->RemoveNode(downstream);
VLOG(4) << "Spliting recomputable anchor pattern: \nupstream "
<< upstream->DebugStr() << "\ndownstream "
<< downstream->DebugStr() << "\nmerged "
<< merged_node->DebugStr();
} else {
upstream->AddNodeToDownstream(downstream);
}
}
if (upstream->downstream().empty()) {
graph->RemoveNode(upstream);
}
}
};

Expand Down
21 changes: 21 additions & 0 deletions paddle/cinn/operator_fusion/pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,27 @@ struct AnchorPattern {
std::vector<pir::Operation*> ops() const { return ops_; }
std::vector<pir::Value> outputs() const { return outputs_; }
pir::Value anchor() const { return anchor_; }

bool can_recompute() const {
// Current Algorithm:
// An AnchorPattern can be recomputed iff:
// 1. It didn't go through any pattern merging during prior fusions, which
// means it only has one output_expr in anchor_state.
// 2. It only contains trivial ops.

if (anchor_state.output_exprs.size() > 1) {
return false;
}

for (const auto& op : ops_) {
const auto& op_kind = GetOpPatternKind(op);
if (op_kind >= hlir::framework::kReduction) {
return false;
}
}

return true;
}
static std::string name() { return "AnchorPattern"; }
};

Expand Down
7 changes: 7 additions & 0 deletions paddle/cinn/operator_fusion/pattern_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ template <typename T>
StmtPattern<T> MergePatternImpl(const TrivialPattern<T>& first,
const TrivialPattern<T>& second);

template <typename T>
StmtPattern<T> MergePatternImpl(const AnchorPattern<T>& first,
const AnchorPattern<T>& second);

template <typename T>
StmtPattern<T> MergePatternImpl(const HorizontalFusionPattern<T>& first,
const HorizontalFusionPattern<T>& second);
Expand All @@ -193,6 +197,9 @@ StmtPattern<T> MergePattern(const StmtPattern<T>& first,
[&](const TrivialPattern<T>& lhs, const TrivialPattern<T>& rhs) {
return MergePatternImpl(lhs, rhs);
},
[&](const AnchorPattern<T>& lhs, const AnchorPattern<T>& rhs) {
return MergePatternImpl(lhs, rhs);
},
[&](const HorizontalFusionPattern<T>& lhs,
const HorizontalFusionPattern<T>& rhs) {
return MergePatternImpl(lhs, rhs);
Expand Down
2 changes: 0 additions & 2 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ void PatternGraph<T>::ReduceTree_Trivial_Fusion() {

template <typename T>
void PatternGraph<T>::LiftToAnchorPattern() {
// TODO(@wuzhanfei)
GraphTransformer<NodePattern, T, AlwaysTrue<T>, LiftToAnchorPatternOperation>(
this);
}
Expand All @@ -200,7 +199,6 @@ void PatternGraph<T>::AnchorPatternFusion() {

template <typename T>
void PatternGraph<T>::SplitRecomputePattern() {
// TODO(@wuzhanfei)
GraphTransformer<NodePattern,
T,
RecomputeNodeMatcher,
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/operator_fusion/policy/policy_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/cinn/operator_fusion/pattern_node.h"
#include "paddle/cinn/operator_fusion/policy/anchor_search_policy.h"
#include "paddle/cinn/operator_fusion/policy/general_topo_policy.h"
#include "paddle/cinn/operator_fusion/policy/policy_base.h"
#include "paddle/cinn/operator_fusion/policy/relative_judge_policy.h"
Expand Down

0 comments on commit af8078c

Please sign in to comment.