diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 1ec09aac2606..2ab4ca2e1404 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -101,6 +101,7 @@ # ExprFunctor ExprFunctor = expr_functor.ExprFunctor +ExprVisitor = expr_functor.ExprVisitor ExprMutator = expr_functor.ExprMutator # Parser diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index a924847fa238..9ca094158c1c 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -36,9 +36,8 @@ def __init__(self): # pylint: disable=no-else-return def visit(self, expr): """Apply the visitor to an expression.""" - found = self.memo_map.get(expr) - if found: - return found + if expr in self.memo_map: + return self.memo_map[expr] if isinstance(expr, Function): res = self.visit_function(expr) @@ -126,6 +125,68 @@ def visit_match(self, _): raise NotImplementedError() +class ExprVisitor(ExprFunctor): + """ + A visitor over Expr. + + The default behavior recursively traverses the AST. + """ + def visit_tuple(self, t): + for x in t.fields: + self.visit(x) + + def visit_call(self, c): + self.visit(c.op) + for a in c.args: + self.visit(a) + + def visit_var(self, v): + pass + + def visit_let(self, l): + self.visit(l.var) + self.visit(l.value) + self.visit(l.body) + + def visit_function(self, f): + self.visit(f.body) + + def visit_if(self, i): + self.visit(i.cond) + self.visit(i.true_branch) + self.visit(i.false_branch) + + def visit_global_var(self, gv): + pass + + def visit_constructor(self, c): + pass + + def visit_op(self, op): + pass + + def visit_constant(self, const): + pass + + def visit_ref_create(self, r): + self.visit(r.value) + + def visit_ref_read(self, r): + self.visit(r.ref) + + def visit_ref_write(self, r): + self.visit(r.ref) + self.visit(r.value) + + def visit_tuple_getitem(self, t): + self.visit(t.tuple_value) + + def visit_match(self, m): + self.visit(m.data) + for c in m.clause: + self.visit(c.rhs) + + class ExprMutator(ExprFunctor): """ A functional visitor over Expr. diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 2a58c282b4c7..ae5ee7bd8bd4 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -16,34 +16,42 @@ # under the License. import tvm from tvm import relay -from tvm.relay import ExprFunctor, ExprMutator +from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor def check_visit(expr): - ef = ExprFunctor() try: + ef = ExprFunctor() ef.visit(expr) assert False except NotImplementedError: pass + ev = ExprVisitor() + ev.visit(expr) + em = ExprMutator() assert em.visit(expr) + def test_constant(): check_visit(relay.const(1.0)) + def test_tuple(): t = relay.Tuple([relay.var('x', shape=())]) check_visit(t) + def test_var(): v = relay.var('x', shape=()) check_visit(v) + def test_global(): v = relay.GlobalVar('f') check_visit(v) + def test_function(): x = relay.var('x', shape=()) y = relay.var('y', shape=()) @@ -61,12 +69,14 @@ def test_function(): ) check_visit(f) + def test_call(): x = relay.var('x', shape=()) y = relay.var('y', shape=()) call = relay.op.add(x, y) check_visit(call) + def test_let(): x = relay.var('x', shape=()) value = relay.const(2.0) @@ -74,30 +84,43 @@ def test_let(): l = relay.Let(x, value, body) check_visit(l) + def test_ite(): cond = relay.var('x', shape=(), dtype='bool') ite = relay.If(cond, cond, cond) check_visit(ite) + def test_get_item(): t = relay.Tuple([relay.var('x', shape=())]) t = relay.TupleGetItem(t, 0) check_visit(t) + def test_ref_create(): r = relay.expr.RefCreate(relay.const(1.0)) check_visit(r) + def test_ref_read(): ref = relay.expr.RefCreate(relay.const(1.0)) r = relay.expr.RefRead(ref) check_visit(r) + def test_ref_write(): ref = relay.expr.RefCreate(relay.const(1.0)) r = relay.expr.RefWrite(ref, relay.const(2.0)) check_visit(r) + +def test_memo(): + expr = relay.const(1) + for _ in range(100): + expr = expr + expr + check_visit(expr) + + if __name__ == "__main__": test_constant() test_tuple() @@ -110,3 +133,4 @@ def test_ref_write(): test_ref_create() test_ref_read() test_ref_write() + test_memo()