Skip to content

Commit

Permalink
fix tranpsoe inline leaf reshape (#69581)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi authored Nov 21, 2024
1 parent 3c1ca43 commit 7ca7f2c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 8 additions & 8 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,18 @@ struct TransposeOpMatcher {

struct ReshapeOpMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
return (node->sink_op()->name() == "cinn_op.reshape");
return node->ops().size() == 1 &&
node->sink_op()->name() == "cinn_op.reshape";
}
};

struct ReshapeConnectionMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
bool upstream_match =
node->downstream().size() == 1 &&
node->downstream()[0]->sink_op()->name() == "cinn_op.reshape" &&
node->downstream()[0]->downstream().size() == 1;
bool downstream_match = node->sink_op()->name() == "cinn_op.reshape" &&
node->downstream().size() == 1;
bool upstream_match = node->downstream().size() == 1 &&
ReshapeOpMatcher()(graph, node->downstream()[0]) &&
node->downstream()[0]->downstream().size() == 1;
bool downstream_match =
ReshapeOpMatcher()(graph, node) && node->downstream().size() == 1;
return upstream_match || downstream_match;
}
};
Expand All @@ -255,7 +255,7 @@ struct LeafReshapeConnectionMatcher {
});
};
const auto match_downstream = [&graph](const PatternNodePtr& downstream) {
return downstream->sink_op()->name() == "cinn_op.reshape" &&
return ReshapeOpMatcher()(graph, downstream) &&
downstream->downstream().size() == 1 &&
downstream->downstream()[0]->downstream().empty() &&
downstream->fusion_iters().loop_iters ==
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/operator_fusion/pattern_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ struct PatternNode {
}

pir::Operation* sink_op() const { return sink_op_; }
std::vector<pir::Operation*> ops() const {
return GetOpsInPattern(stmt_pattern_);
}
const StmtPattern& stmt_pattern() const { return stmt_pattern_; }
void set_stmt_pattern(const StmtPattern& pattern) { stmt_pattern_ = pattern; }
const std::vector<PatternNodePtr>& upstream() const { return upstream_; }
Expand Down

0 comments on commit 7ca7f2c

Please sign in to comment.