Skip to content

Commit

Permalink
remove direct access to analysis and transform namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 17, 2020
1 parent c2eb0b2 commit a5cf36b
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 27 deletions.
18 changes: 3 additions & 15 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from . import transform
from . import analysis
from .analysis import call_graph, feature, alpha_equal
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import debug
Expand All @@ -43,26 +43,20 @@
from .op import Op
from .op import nn
from .op import image
from .op import vision
from .op import annotation
from .op import vision
from .op import contrib
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
from .op.algorithm import *
from .op.nn import *
from .op.vision import *
from .op.contrib import *
from .op.image import *
from . import frontend
from . import backend
from . import quantize

# Dialects
from . import qnn

# Load Memory pass
from .transform import memory_alloc

# Required to traverse large programs
setrecursionlimit(10000)

Expand Down Expand Up @@ -151,9 +145,3 @@
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential

# Feature
Feature = feature.Feature

# CallGraph
CallGraph = call_graph.CallGraph
2 changes: 2 additions & 0 deletions python/tvm/relay/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@

# Feature
from . import feature

CallGraph = call_graph.CallGraph
18 changes: 9 additions & 9 deletions tests/python/relay/test_call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_callgraph_construct():
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module)

Expand All @@ -38,7 +38,7 @@ def test_print_element():
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
mod["g1"] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

assert "#refs = 0" in str(call_graph.print_var("g0"))
assert "#refs = 0" in str(call_graph.print_var("g1"))
Expand All @@ -54,13 +54,13 @@ def test_global_call_count():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

assert call_graph.global_call_count(g0) == 0
assert call_graph.global_call_count(g1) == 1
Expand All @@ -77,13 +77,13 @@ def test_ref_count():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

assert call_graph.ref_count(g0) == 1
assert call_graph.ref_count(g1) == 1
Expand All @@ -100,13 +100,13 @@ def test_nested_ref():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

assert call_graph.ref_count(g0) == 2
assert call_graph.ref_count(g1) == 1
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_recursive_func():
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)

assert call_graph.is_recursive(sum_up)
assert call_graph.ref_count(sum_up) == 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tvm import te
import numpy as np
from tvm import relay
from tvm.relay import memory_alloc
from tvm.relay.transform import memory_alloc

def check_vm_alloc(func, check_fn):
mod = tvm.IRModule()
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_to_graph_normal_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, create_executor, transform, Feature
from tvm.relay import op, create_executor, transform
from tvm.relay.analysis import Feature
from tvm.relay.analysis import detect_feature


Expand Down

0 comments on commit a5cf36b

Please sign in to comment.