diff --git a/paddle/cinn/operator_fusion/graph_transformer/operation.h b/paddle/cinn/operator_fusion/graph_transformer/operation.h index fa41ed2d5598d..5f1cd26f4d672 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/operation.h +++ b/paddle/cinn/operator_fusion/graph_transformer/operation.h @@ -132,7 +132,7 @@ struct LiftToHorizontalFusionPatternOperation { struct LiftToAnchorPatternOperation { template void operator()(PatternGraph* graph, PatternNodePtr node) { - // TODO(@wuzhanfei) + node->set_stmt_pattern(AnchorPattern(node->stmt_pattern())); } }; diff --git a/paddle/cinn/operator_fusion/pattern.h b/paddle/cinn/operator_fusion/pattern.h index d08e820f23ed1..b9298dbc3fb76 100644 --- a/paddle/cinn/operator_fusion/pattern.h +++ b/paddle/cinn/operator_fusion/pattern.h @@ -79,6 +79,24 @@ struct ReduceTreePlusTrivialPattern { std::vector fake_reduce_iter_idx; }; +template +struct AnchorPattern { + explicit AnchorPattern(const StmtPattern& pattern) : pattern_(pattern) { + ExtendVector(ops_, GetOpsInPattern(pattern)); + // TODO(@wuzhanfei): initialize anchor_ and anchor_state using ops_ and + // pattern + } + + StmtPattern pattern_; + std::vector ops_; + pir::Value anchor_; // Choose only one anchor + AnchorState anchor_state; + std::vector ops() const { return ops_; } + std::vector outputs() const { return outputs_; } + pir::Value anchor() const { return anchor_; } + static std::string name() { return "AnchorPattern"; } +}; + template class UnsupportPattern {}; @@ -90,6 +108,7 @@ using StmtPatternBase = std::variant, ReducePattern, ReduceTreePattern, ReduceTreePlusTrivialPattern, + AnchorPattern, HorizontalFusionPattern, UnsupportPattern>; @@ -100,20 +119,4 @@ struct StmtPattern final : public StmtPatternBase { return static_cast&>(*this); } }; - -template -struct AnchorPattern { - explicit AnchorPattern( - const std::vector& ops, - const pir::Value& anchor const AnchorState& anchor_state) - : ops_(ops), anchor_(anchor), {} - std::vector ops_; - pir::Value anchor_; // Choose only one anchor - AnchorState anchor_state; - std::vector ops() const { return ops_; } - std::vector outputs() const { return outputs_; } - pir::Value anchor() const { return anchor_; } - static std::string name() { return "AnchorPattern"; } -}; - } // namespace cinn::fusion