diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 684dae7cc4817..c3d478140eb4a 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -346,16 +346,28 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } Expr ExprMutator::VisitExpr_(const MatchNode* m) { + bool unchanged = true; std::vector clauses; for (const Clause& p : m->clauses) { - clauses.push_back(VisitClause(p)); + Clause c = VisitClause(p); + clauses.push_back(c); + unchanged &= c.same_as(p); } - return Match(Mutate(m->data), clauses, m->complete); + Expr data = Mutate(m->data); + unchanged &= data.same_as(m->data); + if (unchanged) { + return GetRef(m); + } + return Match(data, clauses, m->complete); } Clause ExprMutator::VisitClause(const Clause& c) { Pattern p = VisitPattern(c->lhs); - return Clause(p, Mutate(c->rhs)); + Expr rhs = Mutate(c->rhs); + if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) { + return c; + } + return Clause(p, rhs); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 5d91dcb140565..467e30bc769de 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1140,6 +1140,18 @@ def conv_bias_relu(x, w, b): assert pattern2.match(relu) assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) +def test_match_match(): + add_pattern = is_op('add')(wildcard(), wildcard()) + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + mod = tvm.IRModule({}) + tvm.relay.prelude.Prelude(mod) + # Apply rewrite on IR including relay.Match + out = rewrite(TestRewrite(), mod['tensor_concatenate_int64']) + assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out) if __name__ == "__main__": test_expr_pattern() @@ -1196,3 +1208,4 @@ def conv_bias_relu(x, w, b): test_partition_check() test_partition_check_types() test_partition_option() + test_match_match()