Skip to content

Commit

Permalink
fix RemoveUnusedFunctions pass
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and wweic committed Jan 14, 2020
1 parent 5699637 commit 8037fc8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
37 changes: 12 additions & 25 deletions src/relay/backend/vm/removed_unused_funcs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor {
called_funcs_{},
visiting_{} {}

void CheckExpr(const Expr& expr) {
if (auto func_node = expr.as<FunctionNode>()) {
auto func = GetRef<Function>(func_node);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
} else if (auto global = expr.as<GlobalVarNode>()) {
called_funcs_.insert(global->name_hint);
auto func = module_->Lookup(global->name_hint);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
} else {
VisitExpr(expr);
}
void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(op->name_hint);
auto func = module_->Lookup(op->name_hint);
VisitExpr(func);
}

void VisitExpr_(const CallNode* call_node) final {
CheckExpr(call_node->op);
for (auto param : call_node->args) {
CheckExpr(param);
void VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
if (visiting_.find(func) == visiting_.end()) {
visiting_.insert(func);
for (auto param : func_node->params) {
ExprVisitor::VisitExpr(param);
}
ExprVisitor::VisitExpr(func_node->body);
}
}

Expand Down
24 changes: 24 additions & 0 deletions tests/python/relay/test_pass_remove_unused_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.relay import transform
from tvm.relay.prelude import Prelude


def test_remove_all_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -29,6 +30,7 @@ def test_remove_all_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['main'])


def test_remove_all_prelude_functions_but_referenced_functions():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['id_func', 'main'])


def test_keep_only_referenced_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main'])


def test_multiple_entry_functions():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -72,6 +76,7 @@ def test_multiple_entry_functions():
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])


def test_globalvar_as_call_arg():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -88,5 +93,24 @@ def test_globalvar_as_call_arg():
l = set([x[0].name_hint for x in mod.functions.items()])
assert 'tensor_array_int32' in l


def test_call_globalvar_without_args():
def get_mod():
mod = relay.Module({})
fn1 = relay.Function([], relay.const(1))
fn2 = relay.Function([], relay.const(2))
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
mod[g2] = fn2
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
return mod
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
assert relay.alpha_equal(mod, ref_mod)


if __name__ == '__main__':
pytest.main()

0 comments on commit 8037fc8

Please sign in to comment.