Skip to content

Commit

Permalink
dicrease the complexity of CalcDep from exponential to linear (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji authored and wweic committed Oct 18, 2019
1 parent cb305f2 commit 4b1b88f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ class CalcDep : private ExprVisitor {
VarMap<size_t> use_map_;

void VisitExpr(const Expr& e) final {
return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
visit_counter_[e.get()]++;
// The dce code seprate variable into three parts:
// used 0 times (remove)
// used 1 times (inline)
// used 2 times (dont do anything).
if (visit_counter_[e.get()] <= 2) {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(e);
}
}

void VisitExpr_(const LetNode* l) final {
Expand Down
9 changes: 9 additions & 0 deletions tests/python/relay/test_pass_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from tvm.relay import Function, transform
from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal
from tvm.relay.op import log, add, equal, subtract
from tvm.relay.testing import inception_v3

import pytest

class env:
def __init__(self):
Expand Down Expand Up @@ -129,6 +131,12 @@ def test_tuple_get_item():
assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))


@pytest.mark.timeout(timeout=10, method="thread")
def test_complexity():
g = inception_v3.get_net(1, 1000, (3, 299, 299), 'float32')
run_opt_pass(g, transform.DeadCodeElimination())


if __name__ == "__main__":
test_let()
test_used_let()
Expand All @@ -138,3 +146,4 @@ def test_tuple_get_item():
test_recursion_dead()
test_op_let()
test_tuple_get_item()
test_complexity()

0 comments on commit 4b1b88f

Please sign in to comment.