diff --git a/paddle/cinn/operator_fusion/graph_transformer/matcher.h b/paddle/cinn/operator_fusion/graph_transformer/matcher.h index 001d234814334..52f8a10e4d76e 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/matcher.h +++ b/paddle/cinn/operator_fusion/graph_transformer/matcher.h @@ -226,18 +226,18 @@ struct TransposeOpMatcher { struct ReshapeOpMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return (node->sink_op()->name() == "cinn_op.reshape"); + return node->ops().size() == 1 && + node->sink_op()->name() == "cinn_op.reshape"; } }; struct ReshapeConnectionMatcher { bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - bool upstream_match = - node->downstream().size() == 1 && - node->downstream()[0]->sink_op()->name() == "cinn_op.reshape" && - node->downstream()[0]->downstream().size() == 1; - bool downstream_match = node->sink_op()->name() == "cinn_op.reshape" && - node->downstream().size() == 1; + bool upstream_match = node->downstream().size() == 1 && + ReshapeOpMatcher()(graph, node->downstream()[0]) && + node->downstream()[0]->downstream().size() == 1; + bool downstream_match = + ReshapeOpMatcher()(graph, node) && node->downstream().size() == 1; return upstream_match || downstream_match; } }; @@ -255,7 +255,7 @@ struct LeafReshapeConnectionMatcher { }); }; const auto match_downstream = [&graph](const PatternNodePtr& downstream) { - return downstream->sink_op()->name() == "cinn_op.reshape" && + return ReshapeOpMatcher()(graph, downstream) && downstream->downstream().size() == 1 && downstream->downstream()[0]->downstream().empty() && downstream->fusion_iters().loop_iters == diff --git a/paddle/cinn/operator_fusion/pattern_node.h b/paddle/cinn/operator_fusion/pattern_node.h index 679c6a44d7785..b2eff0bca45f4 100644 --- a/paddle/cinn/operator_fusion/pattern_node.h +++ b/paddle/cinn/operator_fusion/pattern_node.h @@ -66,6 +66,9 @@ struct PatternNode { } pir::Operation* sink_op() const { return sink_op_; } + std::vector ops() const { + return GetOpsInPattern(stmt_pattern_); + } const StmtPattern& stmt_pattern() const { return stmt_pattern_; } void set_stmt_pattern(const StmtPattern& pattern) { stmt_pattern_ = pattern; } const std::vector& upstream() const { return upstream_; }