From 8037fc82f2c29b9c4779c95e1566646f1f615593 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 13 Jan 2020 22:40:26 +0000 Subject: [PATCH] fix RemoveUnusedFunctions pass --- src/relay/backend/vm/removed_unused_funcs.cc | 37 ++++++------------- .../test_pass_remove_unused_functions.py | 24 ++++++++++++ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 419b09588a7b..23bcdc373e26 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor { called_funcs_{}, visiting_{} {} - void CheckExpr(const Expr& expr) { - if (auto func_node = expr.as()) { - auto func = GetRef(func_node); - auto it = visiting_.find(func); - if (it != visiting_.end()) { - return; - } - visiting_.insert(func); - VisitExpr(func); - } else if (auto global = expr.as()) { - called_funcs_.insert(global->name_hint); - auto func = module_->Lookup(global->name_hint); - auto it = visiting_.find(func); - if (it != visiting_.end()) { - return; - } - visiting_.insert(func); - VisitExpr(func); - } else { - VisitExpr(expr); - } + void VisitExpr_(const GlobalVarNode* op) final { + called_funcs_.insert(op->name_hint); + auto func = module_->Lookup(op->name_hint); + VisitExpr(func); } - void VisitExpr_(const CallNode* call_node) final { - CheckExpr(call_node->op); - for (auto param : call_node->args) { - CheckExpr(param); + void VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + if (visiting_.find(func) == visiting_.end()) { + visiting_.insert(func); + for (auto param : func_node->params) { + ExprVisitor::VisitExpr(param); + } + ExprVisitor::VisitExpr(func_node->body); } } diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 97d8646922c0..2a4cbd2579e7 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -20,6 +20,7 @@ from tvm.relay import transform from tvm.relay.prelude import Prelude + def test_remove_all_prelude_functions(): mod = relay.Module() p = Prelude(mod) @@ -29,6 +30,7 @@ def test_remove_all_prelude_functions(): l = set([x[0].name_hint for x in mod.functions.items()]) assert l == set(['main']) + def test_remove_all_prelude_functions_but_referenced_functions(): mod = relay.Module() p = Prelude(mod) @@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions(): l = set([x[0].name_hint for x in mod.functions.items()]) assert l == set(['id_func', 'main']) + def test_keep_only_referenced_prelude_functions(): mod = relay.Module() p = Prelude(mod) @@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions(): l = set([x[0].name_hint for x in mod.functions.items()]) assert l == set(['tl', 'hd', 'main']) + def test_multiple_entry_functions(): mod = relay.Module() p = Prelude(mod) @@ -72,6 +76,7 @@ def test_multiple_entry_functions(): l = set([x[0].name_hint for x in mod.functions.items()]) assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1']) + def test_globalvar_as_call_arg(): mod = relay.Module() p = Prelude(mod) @@ -88,5 +93,24 @@ def test_globalvar_as_call_arg(): l = set([x[0].name_hint for x in mod.functions.items()]) assert 'tensor_array_int32' in l + +def test_call_globalvar_without_args(): + def get_mod(): + mod = relay.Module({}) + fn1 = relay.Function([], relay.const(1)) + fn2 = relay.Function([], relay.const(2)) + g1 = relay.GlobalVar('g1') + g2 = relay.GlobalVar('g2') + mod[g1] = fn1 + mod[g2] = fn2 + p = relay.var('p', 'bool') + mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), [])) + return mod + mod = get_mod() + ref_mod = get_mod() + mod = relay.transform.RemoveUnusedFunctions()(mod) + assert relay.alpha_equal(mod, ref_mod) + + if __name__ == '__main__': pytest.main()