Skip to content

Commit

Permalink
[Relay] Fix dataflow_pattern.rewrite() hang if Match in IR
Browse files Browse the repository at this point in the history
  rewrite() quits only if graph stop changing, but ExprMutator
  always creates new Match node. This patch fixes this.
  • Loading branch information
lixiaoquan committed May 29, 2020
1 parent 95b3ad9 commit f46285c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,28 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef<Expr>(c); }

Expr ExprMutator::VisitExpr_(const MatchNode* m) {
bool unchanged = true;
std::vector<Clause> clauses;
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
Clause c = VisitClause(p);
clauses.push_back(c);
unchanged &= c.same_as(p);
}
return Match(Mutate(m->data), clauses, m->complete);
Expr data = Mutate(m->data);
unchanged &= data.same_as(m->data);
if (unchanged) {
return GetRef<Expr>(m);
}
return Match(data, clauses, m->complete);
}

Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
return Clause(p, Mutate(c->rhs));
Expr rhs = Mutate(c->rhs);
if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
return c;
}
return Clause(p, rhs);
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,18 @@ def conv_bias_relu(x, w, b):
assert pattern2.match(relu)
assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))

def test_match_match():
add_pattern = is_op('add')(wildcard(), wildcard())
class TestRewrite(DFPatternCallback):
def __init__(self):
self.pattern = add_pattern
def callback(self, pre, post, node_map):
return post.args[0] - post.args[1]
mod = tvm.IRModule({})
tvm.relay.prelude.Prelude(mod)
# Apply rewrite on IR including relay.Match
out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)

if __name__ == "__main__":
test_expr_pattern()
Expand Down Expand Up @@ -1196,3 +1208,4 @@ def conv_bias_relu(x, w, b):
test_partition_check()
test_partition_check_types()
test_partition_option()
test_match_match()

0 comments on commit f46285c

Please sign in to comment.