From bc6530500c319e80904d65d4b54219b40ec75a85 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 16 Mar 2020 18:18:02 +0000 Subject: [PATCH] revert relay/ir/*.py to relay --- docs/api/python/relay/base.rst | 12 +- docs/api/python/relay/expr.rst | 38 ++--- docs/api/python/relay/scope_builder.rst | 6 +- .../graph_tuner/utils/traverse_graph.py | 4 +- python/tvm/relay/__init__.py | 108 +++++++------- python/tvm/relay/{ir => }/_ffi_api.py | 0 python/tvm/relay/{ir => }/_parser.py | 10 +- python/tvm/relay/{ir => }/adt.py | 16 +- python/tvm/relay/analysis/analysis.py | 56 +++---- python/tvm/relay/analysis/call_graph.py | 2 +- python/tvm/relay/backend/compile_engine.py | 10 +- python/tvm/relay/backend/interpreter.py | 12 +- python/tvm/relay/{ir => }/base.py | 34 +---- python/tvm/relay/{ir => }/expr.py | 25 ++-- python/tvm/relay/{ir => }/expr_functor.py | 2 +- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 4 +- python/tvm/relay/ir/__init__.py | 95 ------------ python/tvm/relay/{ir => }/loops.py | 0 python/tvm/relay/op/__init__.py | 3 +- python/tvm/relay/op/_tensor_grad.py | 2 +- python/tvm/relay/op/algorithm.py | 2 +- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/relay/op/op.py | 5 +- python/tvm/relay/op/op_attrs.py | 140 ++++++++++-------- python/tvm/relay/op/reduce.py | 2 +- python/tvm/relay/op/tensor.py | 2 +- python/tvm/relay/op/transform.py | 6 +- python/tvm/relay/op/vision/multibox.py | 2 +- python/tvm/relay/op/vision/nms.py | 2 +- python/tvm/relay/{ir => }/parser.py | 4 +- python/tvm/relay/{ir => }/prelude.py | 4 +- python/tvm/relay/qnn/op/qnn.py | 2 +- python/tvm/relay/quantize/_annotate.py | 3 +- python/tvm/relay/quantize/_partition.py | 3 +- python/tvm/relay/quantize/quantize.py | 4 +- python/tvm/relay/{ir => }/scope_builder.py | 2 +- python/tvm/relay/testing/__init__.py | 6 +- python/tvm/relay/testing/nat.py | 6 +- python/tvm/relay/testing/py_converter.py | 6 +- python/tvm/relay/transform/memory_alloc.py | 6 +- python/tvm/relay/transform/transform.py | 3 +- python/tvm/relay/{ir => }/ty.py | 2 +- python/tvm/relay/{ir => }/type_functor.py | 0 tests/python/relay/test_adt.py | 2 +- ...st_feature.py => test_analysis_feature.py} | 2 +- tests/python/relay/test_any.py | 2 +- .../python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_ir_module.py | 2 +- tests/python/relay/test_ir_well_formed.py | 2 +- tests/python/relay/test_pass_annotation.py | 3 +- tests/python/relay/test_pass_gradient.py | 2 +- tests/python/relay/test_pass_manager.py | 3 +- tests/python/relay/test_pass_partial_eval.py | 8 +- .../python/relay/test_pass_partition_graph.py | 2 +- .../test_pass_remove_unused_functions.py | 2 +- .../relay/test_pass_to_a_normal_form.py | 2 +- tests/python/relay/test_pass_to_cps.py | 3 +- .../python/relay/test_pass_unmatched_cases.py | 2 +- tests/python/relay/test_py_converter.py | 3 +- tests/python/relay/test_type_functor.py | 4 +- tests/python/relay/test_vm.py | 6 +- tests/python/relay/test_vm_serialization.py | 4 +- .../test_autotvm_graph_tuner_utils.py | 2 +- 64 files changed, 304 insertions(+), 409 deletions(-) rename python/tvm/relay/{ir => }/_ffi_api.py (100%) rename python/tvm/relay/{ir => }/_parser.py (99%) rename python/tvm/relay/{ir => }/adt.py (92%) rename python/tvm/relay/{ir => }/base.py (59%) rename python/tvm/relay/{ir => }/expr.py (97%) rename python/tvm/relay/{ir => }/expr_functor.py (99%) delete mode 100644 python/tvm/relay/ir/__init__.py rename python/tvm/relay/{ir => }/loops.py (100%) rename python/tvm/relay/{ir => }/parser.py (94%) rename python/tvm/relay/{ir => }/prelude.py (99%) rename python/tvm/relay/{ir => }/scope_builder.py (99%) rename python/tvm/relay/{ir => }/ty.py (97%) rename python/tvm/relay/{ir => }/type_functor.py (100%) rename tests/python/relay/{test_feature.py => test_analysis_feature.py} (98%) diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst index 61b18ac9c815..dc9dac0f67bd 100644 --- a/docs/api/python/relay/base.rst +++ b/docs/api/python/relay/base.rst @@ -15,16 +15,16 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.base +tvm.relay.base -------------- -.. automodule:: tvm.relay.ir.base +.. automodule:: tvm.relay.base -.. autofunction:: tvm.relay.ir.base.register_relay_node +.. autofunction:: tvm.relay.base.register_relay_node -.. autofunction:: tvm.relay.ir.base.register_relay_attr_node +.. autofunction:: tvm.relay.base.register_relay_attr_node -.. autoclass:: tvm.relay.ir.base.RelayNode +.. autoclass:: tvm.relay.base.RelayNode :members: -.. autoclass:: tvm.relay.ir.base.Id +.. autoclass:: tvm.relay.base.Id :members: diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst index e7944730a4b8..57a4a2511b72 100644 --- a/docs/api/python/relay/expr.rst +++ b/docs/api/python/relay/expr.rst @@ -15,56 +15,56 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.expr +tvm.relay.expr -------------- -.. automodule:: tvm.relay.ir.expr +.. automodule:: tvm.relay.expr -.. autofunction:: tvm.relay.ir.expr.var +.. autofunction:: tvm.relay.expr.var -.. autofunction:: tvm.relay.ir.expr.const +.. autofunction:: tvm.relay.expr.const -.. autofunction:: tvm.relay.ir.expr.bind +.. autofunction:: tvm.relay.expr.bind -.. autoclass:: tvm.relay.ir.expr.Expr +.. autoclass:: tvm.relay.expr.Expr :members: -.. autoclass:: tvm.relay.ir.expr.Constant +.. autoclass:: tvm.relay.expr.Constant :members: -.. autoclass:: tvm.relay.ir.expr.Tuple +.. autoclass:: tvm.relay.expr.Tuple :members: -.. autoclass:: tvm.relay.ir.expr.Function +.. autoclass:: tvm.relay.expr.Function :members: -.. autoclass:: tvm.relay.ir.expr.Call +.. autoclass:: tvm.relay.expr.Call :members: -.. autoclass:: tvm.relay.ir.expr.Let +.. autoclass:: tvm.relay.expr.Let :members: -.. autoclass:: tvm.relay.ir.expr.If +.. autoclass:: tvm.relay.expr.If :members: -.. autoclass:: tvm.relay.ir.expr.TupleGetItem +.. autoclass:: tvm.relay.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.ir.expr.RefCreate +.. autoclass:: tvm.relay.expr.RefCreate :members: -.. autoclass:: tvm.relay.ir.expr.RefRead +.. autoclass:: tvm.relay.expr.RefRead :members: -.. autoclass:: tvm.relay.ir.expr.RefWrite +.. autoclass:: tvm.relay.expr.RefWrite :members: -.. autoclass:: tvm.relay.ir.expr.TupleGetItem +.. autoclass:: tvm.relay.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.ir.expr.TempExpr +.. autoclass:: tvm.relay.expr.TempExpr :members: -.. autoclass:: tvm.relay.ir.expr.TupleWrapper +.. autoclass:: tvm.relay.expr.TupleWrapper :members: diff --git a/docs/api/python/relay/scope_builder.rst b/docs/api/python/relay/scope_builder.rst index 730751f7a581..6d8e01428e31 100644 --- a/docs/api/python/relay/scope_builder.rst +++ b/docs/api/python/relay/scope_builder.rst @@ -15,10 +15,10 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.scope_builder +tvm.relay.scope_builder ----------------------- -.. automodule:: tvm.relay.ir.scope_builder +.. automodule:: tvm.relay.scope_builder -.. autoclass:: tvm.relay.ir.scope_builder.ScopeBuilder +.. autoclass:: tvm.relay.scope_builder.ScopeBuilder :members: diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index e463d20242d7..f1dd40440532 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,8 +21,8 @@ import tvm from tvm import relay, autotvm from tvm.relay import transform -from tvm.relay.ir import Call, Function, TupleGetItem, Var, Constant, Tuple -from tvm.relay.ir import TupleType, TensorType +from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 8184b586016c..f24fe44f9041 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,9 +19,16 @@ import os from sys import setrecursionlimit -from . import ir -from .ir import adt, expr, ty, base, scope_builder -from .ir import prelude, loops, parser +from . import base +from . import ty +from . import expr +from . import type_functor +from . import expr_functor +from . import adt +from . import prelude +from . import loops +from . import scope_builder +from . import parser from . import transform from . import analysis @@ -60,65 +67,66 @@ setrecursionlimit(10000) # Span -Span = ir.Span +Span = base.Span +SourceName = base.SourceName # Type -Type = ir.Type -TupleType = ir.TupleType -TensorType = ir.TensorType -TypeKind = ir.TypeKind -TypeVar = ir.TypeVar -ShapeVar = ir.ShapeVar -TypeConstraint = ir.TypeConstraint -FuncType = ir.FuncType -TypeRelation = ir.TypeRelation -IncompleteType = ir.IncompleteType -scalar_type = ir.scalar_type -RefType = ir.RefType -GlobalTypeVar = ir.GlobalTypeVar -TypeCall = ir.TypeCall -Any = ir.Any +Type = ty.Type +TupleType = ty.TupleType +TensorType = ty.TensorType +TypeKind = ty.TypeKind +TypeVar = ty.TypeVar +ShapeVar = ty.ShapeVar +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType +TypeRelation = ty.TypeRelation +IncompleteType = ty.IncompleteType +scalar_type = ty.scalar_type +RefType = ty.RefType +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall +Any = ty.Any # Expr -Expr = ir.Expr -Constant = ir.Constant -Tuple = ir.Tuple -Var = ir.Var -GlobalVar = ir.GlobalVar -Function = ir.Function -Call = ir.Call -Let = ir.Let -If = ir.If -TupleGetItem = ir.TupleGetItem -RefCreate = ir.RefCreate -RefRead = ir.RefRead -RefWrite = ir.RefWrite +Expr = expr.RelayExpr +Constant = expr.Constant +Tuple = expr.Tuple +Var = expr.Var +GlobalVar = expr.GlobalVar +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If +TupleGetItem = expr.TupleGetItem +RefCreate = expr.RefCreate +RefRead = expr.RefRead +RefWrite = expr.RefWrite # ADT -Pattern = ir.Pattern -PatternWildcard = ir.PatternWildcard -PatternVar = ir.PatternVar -PatternConstructor = ir.PatternConstructor -PatternTuple = ir.PatternTuple -Constructor = ir.Constructor -TypeData = ir.TypeData -Clause = ir.Clause -Match = ir.Match +Pattern = adt.Pattern +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +PatternTuple = adt.PatternTuple +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match # helper functions -var = ir.var -const = ir.const -bind = ir.bind +var = expr.var +const = expr.const +bind = expr.bind # TypeFunctor -TypeFunctor = ir.TypeFunctor -TypeVisitor = ir.TypeVisitor -TypeMutator = ir.TypeMutator +TypeFunctor = type_functor.TypeFunctor +TypeVisitor = type_functor.TypeVisitor +TypeMutator = type_functor.TypeMutator # ExprFunctor -ExprFunctor = ir.ExprFunctor -ExprVisitor = ir.ExprVisitor -ExprMutator = ir.ExprMutator +ExprFunctor = expr_functor.ExprFunctor +ExprVisitor = expr_functor.ExprVisitor +ExprMutator = expr_functor.ExprMutator # Prelude Prelude = prelude.Prelude diff --git a/python/tvm/relay/ir/_ffi_api.py b/python/tvm/relay/_ffi_api.py similarity index 100% rename from python/tvm/relay/ir/_ffi_api.py rename to python/tvm/relay/_ffi_api.py diff --git a/python/tvm/relay/ir/_parser.py b/python/tvm/relay/_parser.py similarity index 99% rename from python/tvm/relay/ir/_parser.py rename to python/tvm/relay/_parser.py index 354014a78862..49bdbb393c2e 100644 --- a/python/tvm/relay/ir/_parser.py +++ b/python/tvm/relay/_parser.py @@ -40,11 +40,11 @@ def __new__(cls, *args, **kwds): import tvm.ir._ffi_api from tvm.ir import IRModule -from . import Span, SourceName +from .base import Span, SourceName from . import adt from . import expr from . import ty -from .. import op +from . import op PYTHON_VERSION = sys.version_info.major try: @@ -56,9 +56,9 @@ def __new__(cls, *args, **kwds): .format(version=PYTHON_VERSION)) try: - from ..grammar.py3.RelayVisitor import RelayVisitor - from ..grammar.py3.RelayParser import RelayParser - from ..grammar.py3.RelayLexer import RelayLexer + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer except ImportError: raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") diff --git a/python/tvm/relay/ir/adt.py b/python/tvm/relay/adt.py similarity index 92% rename from python/tvm/relay/ir/adt.py rename to python/tvm/relay/adt.py index 8b4127286948..df12aaece2da 100644 --- a/python/tvm/relay/ir/adt.py +++ b/python/tvm/relay/adt.py @@ -17,8 +17,10 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Algebraic data types in Relay.""" from tvm.ir import Constructor, TypeData +from tvm.runtime import Object +import tvm._ffi -from .base import RelayNode, register_relay_node, Object +from .base import RelayNode from . import _ffi_api from .ty import Type from .expr import ExprWithOp, RelayExpr, Call @@ -28,7 +30,7 @@ class Pattern(RelayNode): """Base type for pattern matching constructs.""" -@register_relay_node +@tvm._ffi.register_object("relay.PatternWildcard") class PatternWildcard(Pattern): """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" @@ -47,7 +49,7 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.PatternWildcard) -@register_relay_node +@tvm._ffi.register_object("relay.PatternVar") class PatternVar(Pattern): """Variable pattern in Relay: Matches anything and binds it to the variable.""" @@ -66,7 +68,7 @@ def __init__(self, var): self.__init_handle_by_constructor__(_ffi_api.PatternVar, var) -@register_relay_node +@tvm._ffi.register_object("relay.PatternConstructor") class PatternConstructor(Pattern): """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" @@ -91,7 +93,7 @@ def __init__(self, constructor, patterns=None): self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.PatternTuple") class PatternTuple(Pattern): """Constructor pattern in Relay: Matches a tuple, binds recursively.""" @@ -114,7 +116,7 @@ def __init__(self, patterns=None): self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.Clause") class Clause(Object): """Clause for pattern matching in Relay.""" @@ -136,7 +138,7 @@ def __init__(self, lhs, rhs): self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs) -@register_relay_node +@tvm._ffi.register_object("relay.Match") class Match(ExprWithOp): """Pattern matching expression in Relay.""" diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 2e4465bac08c..beb3c6599e28 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -24,7 +24,7 @@ from . import _ffi_api from .feature import Feature -from ..ir import Type +from ..ty import Type def post_order_visit(expr, fvisit): @@ -34,7 +34,7 @@ def post_order_visit(expr, fvisit): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. fvisit : function @@ -48,7 +48,7 @@ def well_formed(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -95,7 +95,7 @@ def check_constant(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -111,7 +111,7 @@ def free_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -133,7 +133,7 @@ def bound_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -149,7 +149,7 @@ def all_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -165,7 +165,7 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -185,7 +185,7 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -205,7 +205,7 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -225,10 +225,10 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns @@ -244,10 +244,10 @@ def assert_alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. """ _ffi_api._assert_alpha_equal(lhs, rhs) @@ -261,10 +261,10 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns @@ -283,10 +283,10 @@ def assert_graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. """ _ffi_api._assert_graph_equal(lhs, rhs) @@ -298,13 +298,13 @@ def collect_device_info(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns ------- ret : Dict[tvm.relay.ir.expr, int] - A dictionary mapping tvm.relay.ir.Expr to device type. + A dictionary mapping tvm.relay.Expr to device type. """ return _ffi_api.CollectDeviceInfo(expr) @@ -314,13 +314,13 @@ def collect_device_annotation_ops(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns ------- - ret : Dict[tvm.relay.ir.Expr, int] - A dictionary mapping tvm.relay.ir.Expr to device type where the keys are + ret : Dict[tvm.relay.Expr, int] + A dictionary mapping tvm.relay.Expr to device type where the keys are annotation expressions. """ return _ffi_api.CollectDeviceAnnotationOps(expr) @@ -332,7 +332,7 @@ def get_total_mac_number(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns @@ -369,10 +369,10 @@ def detect_feature(a, b=None): Parameters ---------- - a : Union[tvm.relay.ir.Expr, tvm.IRModule] + a : Union[tvm.relay.Expr, tvm.IRModule] The input expression or module. - b : Optional[Union[tvm.relay.ir.Expr, tvm.IRModule]] + b : Optional[Union[tvm.relay.Expr, tvm.IRModule]] The input expression or module. The two arguments cannot both be expression or module. @@ -391,7 +391,7 @@ def structural_hash(value): Parameters ---------- - expr : Union[tvm.relay.ir.Expr, tvm.relay.Type] + expr : Union[tvm.relay.Expr, tvm.relay.Type] The expression to hash. Returns @@ -405,7 +405,7 @@ def structural_hash(value): return int(_ffi_api._type_hash(value)) else: msg = ("found value of type {0} expected" + - "relay.ir.Expr or relay.Type").format(type(value)) + "relay.Expr or relay.Type").format(type(value)) raise TypeError(msg) diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py index 0d1053612c8d..966659aac494 100644 --- a/python/tvm/relay/analysis/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -19,7 +19,7 @@ from tvm.ir import IRModule from tvm.runtime import Object -from ..ir import GlobalVar +from ..expr import GlobalVar from . import _ffi_api diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 38efafeee66e..03d91d5beb0f 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -22,7 +22,7 @@ import numpy as np import tvm from tvm import te -from ..ir.base import register_relay_node, Object +from tvm.runtime import Object from ... import target as _target from ... import autotvm from .. import expr as _expr @@ -33,7 +33,7 @@ logger = logging.getLogger('compile_engine') -@register_relay_node +@tvm._ffi.register_object("relay.LoweredOutput") class LoweredOutput(Object): """Lowered output""" def __init__(self, outputs, implement): @@ -41,7 +41,7 @@ def __init__(self, outputs, implement): _backend._make_LoweredOutput, outputs, implement) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): """Key in the CompileEngine. @@ -58,7 +58,7 @@ def __init__(self, source_func, target): _backend._make_CCacheKey, source_func, target) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): """Value in the CompileEngine, including usage statistics. """ @@ -261,7 +261,7 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@register_relay_node +@tvm._ffi.register_object("relay.CompileEngine") class CompileEngine(Object): """CompileEngine to get lowered code. """ diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index a7245987ab01..ab39f7c56446 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -20,25 +20,25 @@ import numpy as np -from tvm.runtime import container +import tvm._ffi +from tvm.runtime import container, Object from tvm.ir import IRModule from . import _backend from .. import _make, analysis, transform from ... import nd -from ..ir.base import Object, register_relay_node -from ..ir import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const -from ..ir.scope_builder import ScopeBuilder +from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const +from ..scope_builder import ScopeBuilder -@register_relay_node +@tvm._ffi.register_object("relay.ConstructorValue") class ConstructorValue(Object): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) -@register_relay_node +@tvm._ffi.register_object("relay.RefValue") class RefValue(Object): def __init__(self, value): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/ir/base.py b/python/tvm/relay/base.py similarity index 59% rename from python/tvm/relay/ir/base.py rename to python/tvm/relay/base.py index ad801ae23062..2c35681deb80 100644 --- a/python/tvm/relay/ir/base.py +++ b/python/tvm/relay/base.py @@ -23,43 +23,15 @@ from tvm.ir import SourceName, Span, Node as RelayNode -__STD_PATH__ = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), \ - os.pardir), "std") +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + @tvm._ffi.register_func("tvm.relay.std_path") def _std_path(): return __STD_PATH__ -def register_relay_node(type_key=None): - """Register a Relay node type. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -def register_relay_attr_node(type_key=None): - """Register a Relay attribute node. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay.attrs." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -@register_relay_node +@tvm._ffi.register_object("relay.Id") class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. diff --git a/python/tvm/relay/ir/expr.py b/python/tvm/relay/expr.py similarity index 97% rename from python/tvm/relay/ir/expr.py rename to python/tvm/relay/expr.py index 1ee187b27e7b..380cdf7d90ef 100644 --- a/python/tvm/relay/ir/expr.py +++ b/python/tvm/relay/expr.py @@ -20,11 +20,12 @@ from numbers import Number as _Number import numpy as _np +import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.ir import RelayExpr, GlobalVar, BaseFunc -from .base import RelayNode, register_relay_node +from .base import RelayNode from . import _ffi_api from . import ty as _ty @@ -159,7 +160,7 @@ def __call__(self, *args): """ return Call(self, args) -@register_relay_node +@tvm._ffi.register_object("relay.Constant") class Constant(ExprWithOp): """A constant expression in Relay. @@ -172,7 +173,7 @@ def __init__(self, data): self.__init_handle_by_constructor__(_ffi_api.Constant, data) -@register_relay_node +@tvm._ffi.register_object("relay.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -196,7 +197,7 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") -@register_relay_node +@tvm._ffi.register_object("relay.Var") class Var(ExprWithOp): """A local variable in Relay. @@ -224,7 +225,7 @@ def name_hint(self): return name -@register_relay_node +@tvm._ffi.register_object("relay.Function") class Function(BaseFunc): """A function declaration expression. @@ -286,7 +287,7 @@ def with_attr(self, attr_key, attr_value): -@register_relay_node +@tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. @@ -315,7 +316,7 @@ def __init__(self, op, args, attrs=None, type_args=None): _ffi_api.Call, op, args, attrs, type_args) -@register_relay_node +@tvm._ffi.register_object("relay.Let") class Let(ExprWithOp): """Let variable binding expression. @@ -335,7 +336,7 @@ def __init__(self, variable, value, body): _ffi_api.Let, variable, value, body) -@register_relay_node +@tvm._ffi.register_object("relay.If") class If(ExprWithOp): """A conditional expression in Relay. @@ -355,7 +356,7 @@ def __init__(self, cond, true_branch, false_branch): _ffi_api.If, cond, true_branch, false_branch) -@register_relay_node +@tvm._ffi.register_object("relay.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -372,7 +373,7 @@ def __init__(self, tuple_value, index): _ffi_api.TupleGetItem, tuple_value, index) -@register_relay_node +@tvm._ffi.register_object("relay.RefCreate") class RefCreate(ExprWithOp): """Create a new reference from initial value. Parameters @@ -384,7 +385,7 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) -@register_relay_node +@tvm._ffi.register_object("relay.RefRead") class RefRead(ExprWithOp): """Get the value inside the reference. Parameters @@ -396,7 +397,7 @@ def __init__(self, ref): self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) -@register_relay_node +@tvm._ffi.register_object("relay.RefWrite") class RefWrite(ExprWithOp): """ Update the value inside the reference. diff --git a/python/tvm/relay/ir/expr_functor.py b/python/tvm/relay/expr_functor.py similarity index 99% rename from python/tvm/relay/ir/expr_functor.py rename to python/tvm/relay/expr_functor.py index d3beb200e407..8d6923920979 100644 --- a/python/tvm/relay/ir/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -21,7 +21,7 @@ from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause -from ..op import Op +from .op import Op class ExprFunctor: """ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7b16bc9849aa..6da91c17fd94 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -30,7 +30,7 @@ from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op -from ..ir.loops import while_loop +from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d1abb4f1a5d3..29d9d1bcb93b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,12 +27,12 @@ import tvm from tvm.ir import IRModule -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from .. import analysis from .. import expr as _expr from .. import op as _op -from ..ir import ExprMutator +from ..expr_functor import ExprMutator from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/ir/__init__.py b/python/tvm/relay/ir/__init__.py deleted file mode 100644 index 2a141aa271ed..000000000000 --- a/python/tvm/relay/ir/__init__.py +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -"""The Relay IR namespace.""" -from . import base -from . import ty -from . import expr -from . import type_functor -from . import expr_functor -from . import adt -from . import prelude -from . import loops -from . import scope_builder - -# Span -Span = base.Span -SourceName = base.SourceName - -# Type -Type = ty.Type -TupleType = ty.TupleType -TensorType = ty.TensorType -TypeKind = ty.TypeKind -TypeVar = ty.TypeVar -ShapeVar = ty.ShapeVar -TypeConstraint = ty.TypeConstraint -FuncType = ty.FuncType -TypeRelation = ty.TypeRelation -IncompleteType = ty.IncompleteType -scalar_type = ty.scalar_type -RefType = ty.RefType -GlobalTypeVar = ty.GlobalTypeVar -TypeCall = ty.TypeCall -Any = ty.Any - -# Expr -Expr = expr.RelayExpr -Constant = expr.Constant -Tuple = expr.Tuple -Var = expr.Var -GlobalVar = expr.GlobalVar -Function = expr.Function -Call = expr.Call -Let = expr.Let -If = expr.If -TupleGetItem = expr.TupleGetItem -RefCreate = expr.RefCreate -RefRead = expr.RefRead -RefWrite = expr.RefWrite - -# ADT -Pattern = adt.Pattern -PatternWildcard = adt.PatternWildcard -PatternVar = adt.PatternVar -PatternConstructor = adt.PatternConstructor -PatternTuple = adt.PatternTuple -Constructor = adt.Constructor -TypeData = adt.TypeData -Clause = adt.Clause -Match = adt.Match - -# helper functions -var = expr.var -const = expr.const -bind = expr.bind - -# TypeFunctor -TypeFunctor = type_functor.TypeFunctor -TypeVisitor = type_functor.TypeVisitor -TypeMutator = type_functor.TypeMutator - -# ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -ExprVisitor = expr_functor.ExprVisitor -ExprMutator = expr_functor.ExprMutator - -# Prelude -Prelude = prelude.Prelude - -# Scope builder -ScopeBuilder = scope_builder.ScopeBuilder diff --git a/python/tvm/relay/ir/loops.py b/python/tvm/relay/loops.py similarity index 100% rename from python/tvm/relay/ir/loops.py rename to python/tvm/relay/loops.py diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index acc975cbe854..b3054d67885b 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,13 +41,12 @@ from . import _transform from . import _reduce from . import _algorithm -from ..ir.base import register_relay_node def _register_op_make(): # pylint: disable=import-outside-toplevel from . import _make - from ..ir import expr + from .. import expr expr._op_make = _make _register_op_make() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 039b2f738883..33a193799288 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -21,7 +21,7 @@ from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple -from ..ir.expr import Tuple, TupleGetItem, const +from ..expr import Tuple, TupleGetItem, const from . import nn as _nn from .op import register_gradient from .reduce import sum as _sum diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 414b458cbd7c..17fab80118af 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,7 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make -from ..ir.expr import TupleWrapper +from ..expr import TupleWrapper def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 36f2fa565c9f..30918a4183b1 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -17,7 +17,7 @@ #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" from __future__ import absolute_import as _abs -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper from . import _make from .util import get_pad_tuple2d diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 26a6b6ead8cd..e6bd6bf230dd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -19,13 +19,12 @@ import tvm._ffi from tvm.driver import lower, build -from ..ir.base import register_relay_node -from ..ir.expr import RelayExpr +from ..expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make -@register_relay_node +@tvm._ffi.register_object("relay.Op") class Op(RelayExpr): """A Relay operator definition.""" diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 9224f570f39d..2f68f7074427 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -16,314 +16,326 @@ # under the License. """The attributes node used for Relay operators""" +import tvm._ffi from tvm.ir import Attrs -from ..ir.base import register_relay_attr_node -@register_relay_attr_node +def _register_relay_attr_node(type_key=None): + """Register a Relay attribute node. + + Parameters + ---------- + type_key : str or cls + The type key of the node. + """ + return tvm._ffi.register_object( + "relay.attrs." + type_key.__name__)(type_key) + + +@_register_relay_attr_node class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class GlobalPool2DAttrs(Attrs): """Attributes for nn.global_pool""" -@register_relay_attr_node +@_register_relay_attr_node class BiasAddAttrs(Attrs): """Atttribute of nn.bias_add""" -@register_relay_attr_node +@_register_relay_attr_node class DenseAttrs(Attrs): """Attributes for nn.dense""" -@register_relay_attr_node +@_register_relay_attr_node class FIFOBufferAttrs(Attrs): """Attributes for nn.fifo_buffer""" -@register_relay_attr_node +@_register_relay_attr_node class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" -@register_relay_attr_node +@_register_relay_attr_node class UpSampling3DAttrs(Attrs): """Attributes for nn.upsampling3d""" -@register_relay_attr_node +@_register_relay_attr_node class PadAttrs(Attrs): """Attributes for nn.pad""" -@register_relay_attr_node +@_register_relay_attr_node class MirrorPadAttrs(Attrs): """Attributes for nn.mirror_pad""" -@register_relay_attr_node +@_register_relay_attr_node class LeakyReluAttrs(Attrs): """Attributes for nn.leaky_relu""" -@register_relay_attr_node +@_register_relay_attr_node class PReluAttrs(Attrs): """Attributes for nn.prelu""" -@register_relay_attr_node +@_register_relay_attr_node class DropoutAttrs(Attrs): """Attributes for nn.dropout""" -@register_relay_attr_node +@_register_relay_attr_node class BatchNormAttrs(Attrs): """Attributes for nn.batch_norm""" -@register_relay_attr_node +@_register_relay_attr_node class LRNAttrs(Attrs): """Attributes for nn.lrn""" -@register_relay_attr_node +@_register_relay_attr_node class L2NormalizeAttrs(Attrs): """Attributes for nn.l2_normalize""" -@register_relay_attr_node +@_register_relay_attr_node class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@register_relay_attr_node +@_register_relay_attr_node class ResizeAttrs(Attrs): """Attributes for image.resize""" -@register_relay_attr_node +@_register_relay_attr_node class CropAndResizeAttrs(Attrs): """Attributes for image.crop_and_resize""" -@register_relay_attr_node +@_register_relay_attr_node class ArgsortAttrs(Attrs): """Attributes for algorithm.argsort""" -@register_relay_attr_node +@_register_relay_attr_node class OnDeviceAttrs(Attrs): """Attributes for annotation.on_device""" -@register_relay_attr_node +@_register_relay_attr_node class DebugAttrs(Attrs): """Attributes for debug""" -@register_relay_attr_node +@_register_relay_attr_node class DeviceCopyAttrs(Attrs): """Attributes for tensor.device_copy""" -@register_relay_attr_node +@_register_relay_attr_node class CastAttrs(Attrs): """Attributes for transform.cast""" -@register_relay_attr_node +@_register_relay_attr_node class ConcatenateAttrs(Attrs): """Attributes for tensor.concatenate""" -@register_relay_attr_node +@_register_relay_attr_node class TransposeAttrs(Attrs): """Attributes for transform.transpose""" -@register_relay_attr_node +@_register_relay_attr_node class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" -@register_relay_attr_node +@_register_relay_attr_node class TakeAttrs(Attrs): """Attributes for transform.take""" -@register_relay_attr_node +@_register_relay_attr_node class InitOpAttrs(Attrs): """Attributes for ops specifying a tensor""" -@register_relay_attr_node +@_register_relay_attr_node class ArangeAttrs(Attrs): """Attributes used in arange operators""" -@register_relay_attr_node +@_register_relay_attr_node class StackAttrs(Attrs): """Attributes used in stack operators""" -@register_relay_attr_node +@_register_relay_attr_node class RepeatAttrs(Attrs): """Attributes used in repeat operators""" -@register_relay_attr_node +@_register_relay_attr_node class TileAttrs(Attrs): """Attributes used in tile operators""" -@register_relay_attr_node +@_register_relay_attr_node class ReverseAttrs(Attrs): """Attributes used in reverse operators""" -@register_relay_attr_node +@_register_relay_attr_node class SqueezeAttrs(Attrs): """Attributes used in squeeze operators""" -@register_relay_attr_node +@_register_relay_attr_node class SplitAttrs(Attrs): """Attributes for transform.split""" -@register_relay_attr_node +@_register_relay_attr_node class StridedSliceAttrs(Attrs): """Attributes for transform.stranded_slice""" -@register_relay_attr_node +@_register_relay_attr_node class SliceLikeAttrs(Attrs): """Attributes for transform.slice_like""" -@register_relay_attr_node +@_register_relay_attr_node class ClipAttrs(Attrs): """Attributes for transform.clip""" -@register_relay_attr_node +@_register_relay_attr_node class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@register_relay_attr_node +@_register_relay_attr_node class ShapeOfAttrs(Attrs): """Attributes for tensor.shape_of""" -@register_relay_attr_node +@_register_relay_attr_node class MultiBoxPriorAttrs(Attrs): """Attributes for vision.multibox_prior""" -@register_relay_attr_node +@_register_relay_attr_node class MultiBoxTransformLocAttrs(Attrs): """Attributes for vision.multibox_transform_loc""" -@register_relay_attr_node +@_register_relay_attr_node class GetValidCountsAttrs(Attrs): """Attributes for vision.get_valid_counts""" -@register_relay_attr_node +@_register_relay_attr_node class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" -@register_relay_attr_node +@_register_relay_attr_node class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" -@register_relay_attr_node +@_register_relay_attr_node class ROIPoolAttrs(Attrs): """Attributes for vision.roi_pool""" -@register_relay_attr_node +@_register_relay_attr_node class YoloReorgAttrs(Attrs): """Attributes for vision.yolo_reorg""" -@register_relay_attr_node +@_register_relay_attr_node class ProposalAttrs(Attrs): """Attributes used in proposal operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool2DAttrs(Attrs): """Attributes used in max_pool2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool1DAttrs(Attrs): """Attributes used in max_pool1d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool1DAttrs(Attrs): """Attributes used in avg_pool1d operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool3DAttrs(Attrs): """Attributes used in max_pool3d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool3DAttrs(Attrs): """Attributes used in avg_pool3d operators""" -@register_relay_attr_node +@_register_relay_attr_node class BitPackAttrs(Attrs): """Attributes used in bitpack operator""" -@register_relay_attr_node +@_register_relay_attr_node class BinaryConv2DAttrs(Attrs): """Attributes used in bitserial conv2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class BinaryDenseAttrs(Attrs): """Attributes used in bitserial dense operators""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" -@register_relay_attr_node +@_register_relay_attr_node class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 05f8932e396c..d3226012e887 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -20,7 +20,7 @@ from . import _make from .tensor import sqrt from .transform import squeeze -from ..ir.expr import Tuple, TupleWrapper +from ..expr import Tuple, TupleWrapper def argmax(data, axis=None, keepdims=False, exclude=False): """Returns the indices of the maximum values along an axis. diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 90e4604e80f3..77969185c0a7 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -20,7 +20,7 @@ from tvm.runtime import TVMContext as _TVMContext from . import _make -from ..ir.expr import Tuple +from ..expr import Tuple # We create a wrapper function for each operator in the diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 3e18447c2c6d..6a30eb2b7ec9 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -19,7 +19,7 @@ """Transform operators.""" from . import _make -from ..ir.expr import TupleWrapper, const +from ..expr import TupleWrapper, const def cast(data, dtype): @@ -38,7 +38,7 @@ def cast(data, dtype): result : relay.Expr The casted result. """ - from ..ir import _ffi_api as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast(data, dtype) @@ -55,7 +55,7 @@ def cast_like(data, dtype_like): result : relay.Expr The casted result. """ - from .. import _make as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast_like(data, dtype_like) diff --git a/python/tvm/relay/op/vision/multibox.py b/python/tvm/relay/op/vision/multibox.py index 73f4893412fc..55fb01c5eaef 100644 --- a/python/tvm/relay/op/vision/multibox.py +++ b/python/tvm/relay/op/vision/multibox.py @@ -16,7 +16,7 @@ # under the License. """Multibox operations.""" from . import _make -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper def multibox_prior(data, sizes=(1.0,), diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 859c4999545d..cba08bfba824 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -16,7 +16,7 @@ # under the License. """Non-maximum suppression operations.""" from . import _make -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper def get_valid_counts(data, score_threshold, diff --git a/python/tvm/relay/ir/parser.py b/python/tvm/relay/parser.py similarity index 94% rename from python/tvm/relay/ir/parser.py rename to python/tvm/relay/parser.py index 053d299da022..6c4e3131e3c2 100644 --- a/python/tvm/relay/ir/parser.py +++ b/python/tvm/relay/parser.py @@ -16,14 +16,14 @@ # under the License. """A parser for Relay's text format.""" from __future__ import absolute_import -from ... import register_func +from .. import register_func @register_func("relay.fromtext") def fromtext(data, source_name=None): """Parse a Relay program.""" # pylint: disable=import-outside-toplevel - from . import _parser + from tvm.relay import _parser x = _parser.fromtext(data + "\n", source_name) if x is None: raise Exception("cannot parse: ", data) diff --git a/python/tvm/relay/ir/prelude.py b/python/tvm/relay/prelude.py similarity index 99% rename from python/tvm/relay/ir/prelude.py rename to python/tvm/relay/prelude.py index fa68d3ae177a..5288a2e08011 100644 --- a/python/tvm/relay/ir/prelude.py +++ b/python/tvm/relay/prelude.py @@ -20,10 +20,10 @@ from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, Function, GlobalVar, If, const -from ..op.tensor import add, subtract, equal +from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard -from .. import op +from . import op class TensorArrayOps(object): diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index b53992637250..c94a4194daee 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,7 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm.relay.ir import Tuple +from tvm.relay.expr import Tuple from tvm.relay.op.nn.util import get_pad_tuple2d from . import _make diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 26d8a18f7b98..2658a0aa7dad 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -24,7 +24,6 @@ from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg -from ..ir.base import register_relay_node from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op @@ -58,7 +57,7 @@ def simulated_quantize_compute(attrs, inputs, out_type): _reg.register_injective_schedule("annotation.cast_hint") -@register_relay_node +@tvm._ffi.register_object("relay.QAnnotateExpr") class QAnnotateExpr(_expr.TempExpr): """A special kind of Expr for Annotating. diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 90274e879992..bb3db99eed79 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,7 +19,6 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..ir.base import register_relay_node from ..op import op as _reg from . import _quantize from .quantize import _forward_op @@ -30,7 +29,7 @@ def _register(func): return _register(frewrite) if frewrite is not None else _register -@register_relay_node +@tvm._ffi.register_object("relay.QPartitionExpr") class QPartitionExpr(_expr.TempExpr): def __init__(self, expr): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index bd1ab6fd7dfc..2ad4e18771d7 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,12 +17,12 @@ #pylint: disable=unused-argument, not-context-manager """Automatic quantization toolkit.""" import tvm.ir +from tvm.runtime import Object from . import _quantize from ._calibrate import calibrate from .. import expr as _expr from .. import transform as _transform -from ..ir.base import Object, register_relay_node class QAnnotateKind(object): @@ -52,7 +52,7 @@ def _forward_op(ref_call, args): ref_call.op, args, ref_call.attrs, ref_call.type_args) -@register_relay_node("relay.quantize.QConfig") +@tvm._ffi.register_object("relay.quantize.QConfig") class QConfig(Object): """Configure the quantization behavior by setting config variables. diff --git a/python/tvm/relay/ir/scope_builder.py b/python/tvm/relay/scope_builder.py similarity index 99% rename from python/tvm/relay/ir/scope_builder.py rename to python/tvm/relay/scope_builder.py index 35357707e535..cd8dc8dcd309 100644 --- a/python/tvm/relay/ir/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -20,7 +20,7 @@ from . import ty as _ty from . import expr as _expr -from ..._ffi import base as _base +from .._ffi import base as _base class WithScope(object): """A wrapper for builder methods which introduce scoping. diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 068c02c6d53b..54c909179e4f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -23,9 +23,9 @@ from tvm import te import tvm.relay as relay import tvm.relay.op as op -from tvm.relay import transform, create_executor -from tvm.relay.ir import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem -from tvm.relay.ir import TensorType, TupleType +from tvm.relay import transform +from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor +from tvm.relay import TensorType, TupleType from . import mlp from . import resnet diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index d1110fdb19d5..eb71120610d3 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -19,10 +19,10 @@ Nats are useful for testing purposes, as they make it easy to write test cases for recursion and pattern matching.""" -from tvm.relay.ir import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.ir import Var, Function, GlobalVar -from tvm.relay.ir import GlobalTypeVar, TypeVar, FuncType +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType def define_nat_adt(prelude): """Defines a Peano (unary) natural number ADT. diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index e40436c4df22..eacfe379137f 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -21,10 +21,10 @@ import tvm from tvm import relay -from tvm.relay.ir import Pattern +from tvm.relay.adt import Pattern from tvm.relay.backend import compile_engine -from tvm.relay.ir import Expr, Function, GlobalVar, Var -from tvm.relay.ir import ExprFunctor +from tvm.relay.expr import Expr, Function, GlobalVar, Var +from tvm.relay.expr_functor import ExprFunctor OUTPUT_VAR_NAME = '_py_out' diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 7b5a0b75ac6c..c238730807d3 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,12 +19,12 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -from ..ir.expr_functor import ExprMutator -from ..ir.scope_builder import ScopeBuilder +from ..expr_functor import ExprMutator +from ..scope_builder import ScopeBuilder from . import transform from .. import op from ... import DataType, register_func -from ..ir import ty, expr +from .. import ty, expr from ..backend import compile_engine diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index e32de3e3d112..43a116e64e5b 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -29,7 +29,6 @@ from tvm import relay from . import _ffi_api -from ..ir.base import register_relay_node def build_config(opt_level=2, @@ -83,7 +82,7 @@ def build_config(opt_level=2, disabled_pass, trace) -@register_relay_node +@tvm._ffi.register_object("relay.FunctionPass") class FunctionPass(Pass): """A pass that works on each tvm.relay.Function in a module. A function pass class should be created through `function_pass`. diff --git a/python/tvm/relay/ir/ty.py b/python/tvm/relay/ty.py similarity index 97% rename from python/tvm/relay/ir/ty.py rename to python/tvm/relay/ty.py index b9643803f2f6..19cc10aba41e 100644 --- a/python/tvm/relay/ir/ty.py +++ b/python/tvm/relay/ty.py @@ -20,7 +20,7 @@ from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType -from .base import RelayNode, register_relay_node +from .base import RelayNode from . import _ffi_api Any = _ffi_api.Any diff --git a/python/tvm/relay/ir/type_functor.py b/python/tvm/relay/type_functor.py similarity index 100% rename from python/tvm/relay/ir/type_functor.py rename to python/tvm/relay/type_functor.py diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 491f18df3de0..deeb7330f9da 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_analysis_feature.py similarity index 98% rename from tests/python/relay/test_feature.py rename to tests/python/relay/test_analysis_feature.py index f54fa713e957..ec5deb3c4e60 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_analysis_feature.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.relay.analysis import detect_feature, Feature from tvm.relay.transform import gradient -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import run_infer_type def test_prelude(): diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 914ac12ed6ff..aa81e3113b7f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir.loops import while_loop +from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type def int32(val): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 0534c16dedaa..360b6bd20416 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.runtime import container from tvm.relay.backend.interpreter import RefValue, ConstructorValue -from tvm.relay.ir import ScopeBuilder +from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index bfc9accf5906..bab82472263a 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions def constructor_list(p): diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index daf436c584cc..db953d5762e3 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -18,7 +18,7 @@ from tvm import te from tvm import relay from tvm.relay.analysis import well_formed -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude def test_let(): x = relay.Var("x") diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index e98c7ec1fe6f..3e7d916c96fa 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -19,10 +19,9 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.ir import ExprMutator +from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 48923b5ce061..6f2a12589fb5 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal from tvm.relay import create_executor, transform from tvm.relay.transform import gradient -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand import tvm.relay.op as op diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 2b865c57123d..aed026996a21 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -21,7 +21,8 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import ExprFunctor, Function, Call +from tvm.relay import ExprFunctor +from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list, run_infer_type diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f2f6f85955e4..f54dd6bf69c5 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -20,11 +20,11 @@ from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, assert_alpha_equal -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform -from tvm.relay.ir import Var, TypeVar, TupleGetItem, Let, const, RefRead, RefWrite, RefCreate -from tvm.relay.ir import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match -from tvm.relay.ir import GlobalVar, Call, Function +from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate +from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match +from tvm.relay import GlobalVar, Call from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 36e371d68747..c4fbbc1458d9 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -26,7 +26,7 @@ from tvm.relay import transform from tvm.contrib import util from tvm.relay.op.annotation import compiler_begin, compiler_end -from tvm.relay.ir import ExprMutator +from tvm.relay.expr_functor import ExprMutator # Leverage the pass manager to write a simple white list based annotator @transform.function_pass(opt_level=0) diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index e42bdbd94986..33816344f562 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -19,7 +19,7 @@ from tvm import te from tvm import relay from tvm.relay import transform -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude def test_remove_all_prelude_functions(): diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 1f0c9f33cc73..2a6103ea1fbe 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay import op, create_executor, transform -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.analysis import Feature diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 76b906cf96ab..e2ac924e9661 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -16,12 +16,11 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.transform import to_cps, un_cps from tvm.relay.analysis import Feature -from tvm.relay.ir import Prelude, Function +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor from tvm.relay import transform diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index bac0970b098b..42344bccabaa 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.analysis import unmatched_cases import pytest diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 36ed034ba9eb..f6b1b2432d92 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -16,9 +16,10 @@ # under the License. import numpy as np import tvm +from tvm import te from tvm import relay from tvm.relay.testing import to_python, run_as_python -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.runtime.container import ADT from tvm.relay.backend.interpreter import RefValue, ConstructorValue diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index 8f581b5b715c..9e023bc6b1e4 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -19,9 +19,9 @@ from tvm import relay from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor from tvm.relay.analysis import assert_graph_equal -from tvm.relay.ir import (TypeVar, IncompleteType, TensorType, FuncType, +from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType, TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) -from tvm.relay.ir import TypeData +from tvm.relay.adt import TypeData def check_visit(typ): try: diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index b31f04c5edd9..f2b15ec26f32 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -20,10 +20,10 @@ import tvm from tvm import runtime from tvm import relay -from tvm.relay.ir import ScopeBuilder +from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.testing.config import ctx_list -from tvm.relay.ir import Prelude -from tvm.relay.ir.loops import while_loop +from tvm.relay.prelude import Prelude +from tvm.relay.loops import while_loop from tvm.relay import testing def check_result(args, expected_result, mod=None): diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index a2a786e6a8be..5d20651a8126 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -24,8 +24,8 @@ from tvm.relay import vm as rly_vm from tvm import relay -from tvm.relay.ir import ScopeBuilder -from tvm.relay.ir import Prelude +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index b558010823e1..bd0ebe0cd3f5 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -28,7 +28,7 @@ from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ get_out_nodes, expr2graph, bind_inputs from tvm.autotvm.graph_tuner._base import OPT_OUT_OP -from tvm.relay.ir import Call, TupleGetItem, Tuple, Var +from tvm.relay.expr import Call, TupleGetItem, Tuple, Var def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):