From 7ca7f2c351551140f520a2d1460db0497fa627aa Mon Sep 17 00:00:00 2001 From: huangjiyi <43315610+huangjiyi@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:51:48 +0800 Subject: [PATCH] fix tranpsoe inline leaf reshape (#69581) --- .../operator_fusion/graph_transformer/matcher.h | 16 ++++++++-------- paddle/cinn/operator_fusion/pattern_node.h | 3 +++ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/operator_fusion/graph_transformer/matcher.h b/paddle/cinn/operator_fusion/graph_transformer/matcher.h index 001d234814334..52f8a10e4d76e 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/matcher.h +++ b/paddle/cinn/operator_fusion/graph_transformer/matcher.h @@ -226,18 +226,18 @@ struct TransposeOpMatcher { struct ReshapeOpMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return (node->sink_op()->name() == "cinn_op.reshape"); + return node->ops().size() == 1 && + node->sink_op()->name() == "cinn_op.reshape"; } }; struct ReshapeConnectionMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - bool upstream_match = - node->downstream().size() == 1 && - node->downstream()[0]->sink_op()->name() == "cinn_op.reshape" && - node->downstream()[0]->downstream().size() == 1; - bool downstream_match = node->sink_op()->name() == "cinn_op.reshape" && - node->downstream().size() == 1; + bool upstream_match = node->downstream().size() == 1 && + ReshapeOpMatcher()(graph, node->downstream()[0]) && + node->downstream()[0]->downstream().size() == 1; + bool downstream_match = + ReshapeOpMatcher()(graph, node) && node->downstream().size() == 1; return upstream_match || downstream_match; } }; @@ -255,7 +255,7 @@ struct LeafReshapeConnectionMatcher { }); }; const auto match_downstream = [&graph](const PatternNodePtr& downstream) { - return downstream->sink_op()->name() == "cinn_op.reshape" && + return ReshapeOpMatcher()(graph, downstream) && downstream->downstream().size() == 1 && downstream->downstream()[0]->downstream().empty() && downstream->fusion_iters().loop_iters == diff --git a/paddle/cinn/operator_fusion/pattern_node.h b/paddle/cinn/operator_fusion/pattern_node.h index 679c6a44d7785..b2eff0bca45f4 100644 --- a/paddle/cinn/operator_fusion/pattern_node.h +++ b/paddle/cinn/operator_fusion/pattern_node.h @@ -66,6 +66,9 @@ struct PatternNode { } pir::Operation* sink_op() const { return sink_op_; } + std::vector ops() const { + return GetOpsInPattern(stmt_pattern_); + } const StmtPattern& stmt_pattern() const { return stmt_pattern_; } void set_stmt_pattern(const StmtPattern& pattern) { stmt_pattern_ = pattern; } const std::vector& upstream() const { return upstream_; }