diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst index dc9dac0f67bd..8dcab78d3231 100644 --- a/docs/api/python/relay/base.rst +++ b/docs/api/python/relay/base.rst @@ -19,10 +19,6 @@ tvm.relay.base -------------- .. automodule:: tvm.relay.base -.. autofunction:: tvm.relay.base.register_relay_node - -.. autofunction:: tvm.relay.base.register_relay_attr_node - .. autoclass:: tvm.relay.base.RelayNode :members: diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f4a7c75864d5..b1aac3e606a2 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,35 +19,37 @@ import os from sys import setrecursionlimit -from . import call_graph from . import base from . import ty from . import expr from . import type_functor from . import expr_functor from . import adt -from . import analysis +from . import prelude +from . import loops +from . import scope_builder +from . import parser + from . import transform +from . import analysis +from .analysis import alpha_equal from .build_module import build, create_executor, optimize from .transform import build_config -from . import prelude -from . import parser from . import debug from . import param_dict -from . import feature from .backend import vm # Root operators from .op import Op +from .op import nn +from .op import image +from .op import annotation +from .op import vision +from .op import contrib from .op.reduce import * from .op.tensor import * from .op.transform import * from .op.algorithm import * -from . import nn -from . import annotation -from . import vision -from . import contrib -from . import image from . import frontend from . import backend from . import quantize @@ -55,15 +57,12 @@ # Dialects from . import qnn -from .scope_builder import ScopeBuilder -# Load Memory pass -from . import memory_alloc - # Required to traverse large programs setrecursionlimit(10000) # Span Span = base.Span +SourceName = base.SourceName # Type Type = ty.Type @@ -98,6 +97,7 @@ RefWrite = expr.RefWrite # ADT +Pattern = adt.Pattern PatternWildcard = adt.PatternWildcard PatternVar = adt.PatternVar PatternConstructor = adt.PatternConstructor @@ -111,9 +111,6 @@ var = expr.var const = expr.const bind = expr.bind -module_pass = transform.module_pass -function_pass = transform.function_pass -alpha_equal = analysis.alpha_equal # TypeFunctor TypeFunctor = type_functor.TypeFunctor @@ -125,6 +122,15 @@ ExprVisitor = expr_functor.ExprVisitor ExprMutator = expr_functor.ExprMutator +# Prelude +Prelude = prelude.Prelude + +# Scope builder +ScopeBuilder = scope_builder.ScopeBuilder + +module_pass = transform.module_pass +function_pass = transform.function_pass + # Parser fromtext = parser.fromtext @@ -139,9 +145,3 @@ ModulePass = transform.ModulePass FunctionPass = transform.FunctionPass Sequential = transform.Sequential - -# Feature -Feature = feature.Feature - -# CallGraph -CallGraph = call_graph.CallGraph diff --git a/python/tvm/relay/_analysis.py b/python/tvm/relay/_ffi_api.py similarity index 88% rename from python/tvm/relay/_analysis.py rename to python/tvm/relay/_ffi_api.py index 050fcce2fb17..8e9b46a14d35 100644 --- a/python/tvm/relay/_analysis.py +++ b/python/tvm/relay/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the passes for Relay program analysis.""" +"""FFI APIs for Relay program IR.""" import tvm._ffi -tvm._ffi._init_api("relay._analysis", __name__) +tvm._ffi._init_api("relay.ir", __name__) diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 9c5dac6362e2..df12aaece2da 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -17,9 +17,11 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Algebraic data types in Relay.""" from tvm.ir import Constructor, TypeData +from tvm.runtime import Object +import tvm._ffi -from .base import RelayNode, register_relay_node, Object -from . import _make +from .base import RelayNode +from . import _ffi_api from .ty import Type from .expr import ExprWithOp, RelayExpr, Call @@ -28,7 +30,7 @@ class Pattern(RelayNode): """Base type for pattern matching constructs.""" -@register_relay_node +@tvm._ffi.register_object("relay.PatternWildcard") class PatternWildcard(Pattern): """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" @@ -44,10 +46,10 @@ def __init__(self): wildcard: PatternWildcard a wildcard pattern. """ - self.__init_handle_by_constructor__(_make.PatternWildcard) + self.__init_handle_by_constructor__(_ffi_api.PatternWildcard) -@register_relay_node +@tvm._ffi.register_object("relay.PatternVar") class PatternVar(Pattern): """Variable pattern in Relay: Matches anything and binds it to the variable.""" @@ -63,10 +65,10 @@ def __init__(self, var): pv: PatternVar A variable pattern. """ - self.__init_handle_by_constructor__(_make.PatternVar, var) + self.__init_handle_by_constructor__(_ffi_api.PatternVar, var) -@register_relay_node +@tvm._ffi.register_object("relay.PatternConstructor") class PatternConstructor(Pattern): """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" @@ -88,10 +90,10 @@ def __init__(self, constructor, patterns=None): """ if patterns is None: patterns = [] - self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns) + self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.PatternTuple") class PatternTuple(Pattern): """Constructor pattern in Relay: Matches a tuple, binds recursively.""" @@ -111,10 +113,10 @@ def __init__(self, patterns=None): """ if patterns is None: patterns = [] - self.__init_handle_by_constructor__(_make.PatternTuple, patterns) + self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.Clause") class Clause(Object): """Clause for pattern matching in Relay.""" @@ -133,10 +135,10 @@ def __init__(self, lhs, rhs): clause: Clause The Clause. """ - self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) + self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs) -@register_relay_node +@tvm._ffi.register_object("relay.Match") class Match(ExprWithOp): """Pattern matching expression in Relay.""" @@ -160,4 +162,4 @@ def __init__(self, data, clauses, complete=True): match: tvm.relay.Expr The match expression. """ - self.__init_handle_by_constructor__(_make.Match, data, clauses, complete) + self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete) diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py new file mode 100644 index 000000000000..957f5a3dcd94 --- /dev/null +++ b/python/tvm/relay/analysis/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace containing the analysis passes.""" +# Analysis passes +from .analysis import * + +# Call graph +from . import call_graph + +# Feature +from . import feature + +CallGraph = call_graph.CallGraph diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/analysis/_ffi_api.py similarity index 82% rename from python/tvm/relay/_base.py rename to python/tvm/relay/analysis/_ffi_api.py index f86aa70353dc..20b03c396e70 100644 --- a/python/tvm/relay/_base.py +++ b/python/tvm/relay/analysis/_ffi_api.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface of expr function exposed from C++.""" +"""FFI APIs for Relay program analysis.""" import tvm._ffi -tvm._ffi._init_api("relay._base", __name__) +tvm._ffi._init_api("relay.analysis", __name__) diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis/analysis.py similarity index 88% rename from python/tvm/relay/analysis.py rename to python/tvm/relay/analysis/analysis.py index 198e0a3bf9eb..beb3c6599e28 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -22,9 +22,9 @@ """ from tvm.ir import RelayExpr, IRModule -from . import _analysis -from .ty import Type +from . import _ffi_api from .feature import Feature +from ..ty import Type def post_order_visit(expr, fvisit): @@ -40,7 +40,7 @@ def post_order_visit(expr, fvisit): fvisit : function The visitor function to be applied. """ - return _analysis.post_order_visit(expr, fvisit) + return _ffi_api.post_order_visit(expr, fvisit) def well_formed(expr): @@ -56,7 +56,7 @@ def well_formed(expr): well_form : bool Whether the input expression is well formed """ - return _analysis.well_formed(expr) + return _ffi_api.well_formed(expr) def check_kind(t, mod=None): @@ -85,9 +85,9 @@ def check_kind(t, mod=None): assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type """ if mod is not None: - return _analysis.check_kind(t, mod) + return _ffi_api.check_kind(t, mod) else: - return _analysis.check_kind(t) + return _ffi_api.check_kind(t) def check_constant(expr): @@ -103,7 +103,7 @@ def check_constant(expr): result : bool Whether the expression is constant. """ - return _analysis.check_constant(expr) + return _ffi_api.check_constant(expr) def free_vars(expr): @@ -125,7 +125,7 @@ def free_vars(expr): neural networks: usually this means weights of previous are ordered first. """ - return _analysis.free_vars(expr) + return _ffi_api.free_vars(expr) def bound_vars(expr): @@ -141,7 +141,7 @@ def bound_vars(expr): free : List[tvm.relay.Var] The list of bound variables in post-DFS order. """ - return _analysis.bound_vars(expr) + return _ffi_api.bound_vars(expr) def all_vars(expr): @@ -157,7 +157,7 @@ def all_vars(expr): free : List[tvm.relay.Var] The list of all variables in post-DFS order. """ - return _analysis.all_vars(expr) + return _ffi_api.all_vars(expr) def free_type_vars(expr, mod=None): @@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None): The list of free type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.free_type_vars(expr, use_mod) + return _ffi_api.free_type_vars(expr, use_mod) def bound_type_vars(expr, mod=None): @@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None): The list of bound type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.bound_type_vars(expr, use_mod) + return _ffi_api.bound_type_vars(expr, use_mod) def all_type_vars(expr, mod=None): @@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None): The list of all type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.all_type_vars(expr, use_mod) + return _ffi_api.all_type_vars(expr, use_mod) def alpha_equal(lhs, rhs): @@ -236,7 +236,7 @@ def alpha_equal(lhs, rhs): result : bool True iff lhs is alpha equal to rhs. """ - return bool(_analysis._alpha_equal(lhs, rhs)) + return bool(_ffi_api._alpha_equal(lhs, rhs)) def assert_alpha_equal(lhs, rhs): @@ -250,7 +250,7 @@ def assert_alpha_equal(lhs, rhs): rhs : tvm.relay.Expr One of the input Expression. """ - _analysis._assert_alpha_equal(lhs, rhs) + _ffi_api._assert_alpha_equal(lhs, rhs) def graph_equal(lhs, rhs): @@ -272,7 +272,7 @@ def graph_equal(lhs, rhs): result : bool True iff lhs is data-flow equivalent to rhs. """ - return bool(_analysis._graph_equal(lhs, rhs)) + return bool(_ffi_api._graph_equal(lhs, rhs)) def assert_graph_equal(lhs, rhs): @@ -289,7 +289,7 @@ def assert_graph_equal(lhs, rhs): rhs : tvm.relay.Expr One of the input Expression. """ - _analysis._assert_graph_equal(lhs, rhs) + _ffi_api._assert_graph_equal(lhs, rhs) def collect_device_info(expr): @@ -303,10 +303,10 @@ def collect_device_info(expr): Returns ------- - ret : Dict[tvm.relay.expr, int] + ret : Dict[tvm.relay.ir.expr, int] A dictionary mapping tvm.relay.Expr to device type. """ - return _analysis.CollectDeviceInfo(expr) + return _ffi_api.CollectDeviceInfo(expr) def collect_device_annotation_ops(expr): @@ -319,11 +319,11 @@ def collect_device_annotation_ops(expr): Returns ------- - ret : Dict[tvm.relay.expr, int] + ret : Dict[tvm.relay.Expr, int] A dictionary mapping tvm.relay.Expr to device type where the keys are annotation expressions. """ - return _analysis.CollectDeviceAnnotationOps(expr) + return _ffi_api.CollectDeviceAnnotationOps(expr) def get_total_mac_number(expr): @@ -340,7 +340,7 @@ def get_total_mac_number(expr): result : int64 The number of MACs (multiply-accumulate) of a model """ - return _analysis.GetTotalMacNumber(expr) + return _ffi_api.GetTotalMacNumber(expr) def unmatched_cases(match, mod=None): @@ -360,7 +360,7 @@ def unmatched_cases(match, mod=None): missing_patterns : [tvm.relay.Pattern] Patterns that the match expression does not catch. """ - return _analysis.unmatched_cases(match, mod) + return _ffi_api.unmatched_cases(match, mod) def detect_feature(a, b=None): @@ -383,7 +383,7 @@ def detect_feature(a, b=None): """ if isinstance(a, IRModule): a, b = b, a - return {Feature(int(x)) for x in _analysis.detect_feature(a, b)} + return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)} def structural_hash(value): @@ -400,9 +400,9 @@ def structural_hash(value): The hash value """ if isinstance(value, RelayExpr): - return int(_analysis._expr_hash(value)) + return int(_ffi_api._expr_hash(value)) elif isinstance(value, Type): - return int(_analysis._type_hash(value)) + return int(_ffi_api._type_hash(value)) else: msg = ("found value of type {0} expected" + "relay.Expr or relay.Type").format(type(value)) @@ -421,10 +421,10 @@ def extract_fused_functions(mod): Returns ------- - ret : Dict[int, tvm.relay.expr.Function] + ret : Dict[int, tvm.relay.ir.expr.Function] A module containing only fused primitive functions """ - ret_mod = _analysis.ExtractFusedFunctions()(mod) + ret_mod = _ffi_api.ExtractFusedFunctions()(mod) ret = {} for hash_, func in ret_mod.functions.items(): ret[hash_] = func diff --git a/python/tvm/relay/call_graph.py b/python/tvm/relay/analysis/call_graph.py similarity index 88% rename from python/tvm/relay/call_graph.py rename to python/tvm/relay/analysis/call_graph.py index 8206f5dccd4c..966659aac494 100644 --- a/python/tvm/relay/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -18,9 +18,9 @@ """Call graph used in Relay.""" from tvm.ir import IRModule -from .base import Object -from .expr import GlobalVar -from . import _analysis +from tvm.runtime import Object +from ..expr import GlobalVar +from . import _ffi_api class CallGraph(Object): @@ -39,7 +39,7 @@ def __init__(self, module): call_graph: CallGraph A constructed call graph. """ - self.__init_handle_by_constructor__(_analysis.CallGraph, module) + self.__init_handle_by_constructor__(_ffi_api.CallGraph, module) @property def module(self): @@ -54,7 +54,7 @@ def module(self): ret : tvm.ir.IRModule The contained IRModule """ - return _analysis.GetModule(self) + return _ffi_api.GetModule(self) def ref_count(self, var): """Return the number of references to the global var @@ -69,7 +69,7 @@ def ref_count(self, var): The number reference to the global var """ var = self._get_global_var(var) - return _analysis.GetRefCountGlobalVar(self, var) + return _ffi_api.GetRefCountGlobalVar(self, var) def global_call_count(self, var): """Return the number of global function calls from a given global var. @@ -84,7 +84,7 @@ def global_call_count(self, var): The number of global function calls from the given var. """ var = self._get_global_var(var) - return _analysis.GetGlobalVarCallCount(self, var) + return _ffi_api.GetGlobalVarCallCount(self, var) def is_recursive(self, var): """Return if the function corresponding to a var is a recursive @@ -100,7 +100,7 @@ def is_recursive(self, var): If the function corresponding to var is recurisve. """ var = self._get_global_var(var) - return _analysis.IsRecursive(self, var) + return _ffi_api.IsRecursive(self, var) def _get_global_var(self, var): """Return the global var using a given name or GlobalVar. @@ -137,8 +137,8 @@ def print_var(self, var): The call graph represented in string. """ var = self._get_global_var(var) - return _analysis.PrintCallGraphGlobalVar(self, var) + return _ffi_api.PrintCallGraphGlobalVar(self, var) def __str__(self): """Print the call graph in the topological order.""" - return _analysis.PrintCallGraph(self) + return _ffi_api.PrintCallGraph(self) diff --git a/python/tvm/relay/feature.py b/python/tvm/relay/analysis/feature.py similarity index 100% rename from python/tvm/relay/feature.py rename to python/tvm/relay/analysis/feature.py diff --git a/python/tvm/relay/annotation.py b/python/tvm/relay/annotation.py deleted file mode 100644 index 5a4065313dcc..000000000000 --- a/python/tvm/relay/annotation.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Annotation related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.annotation import * diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a51e4f7bad11..03d91d5beb0f 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -22,7 +22,7 @@ import numpy as np import tvm from tvm import te -from ..base import register_relay_node, Object +from tvm.runtime import Object from ... import target as _target from ... import autotvm from .. import expr as _expr @@ -33,7 +33,7 @@ logger = logging.getLogger('compile_engine') -@register_relay_node +@tvm._ffi.register_object("relay.LoweredOutput") class LoweredOutput(Object): """Lowered output""" def __init__(self, outputs, implement): @@ -41,7 +41,7 @@ def __init__(self, outputs, implement): _backend._make_LoweredOutput, outputs, implement) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): """Key in the CompileEngine. @@ -58,7 +58,7 @@ def __init__(self, source_func, target): _backend._make_CCacheKey, source_func, target) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): """Value in the CompileEngine, including usage statistics. """ @@ -261,7 +261,7 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@register_relay_node +@tvm._ffi.register_object("relay.CompileEngine") class CompileEngine(Object): """CompileEngine to get lowered code. """ diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 18f848c212b2..ab39f7c56446 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -20,25 +20,25 @@ import numpy as np -from tvm.runtime import container +import tvm._ffi +from tvm.runtime import container, Object from tvm.ir import IRModule from . import _backend from .. import _make, analysis, transform from ... import nd -from ..base import Object, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder -@register_relay_node +@tvm._ffi.register_object("relay.ConstructorValue") class ConstructorValue(Object): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) -@register_relay_node +@tvm._ffi.register_object("relay.RefValue") class RefValue(Object): def __init__(self, value): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 0d6f22f446cd..2c35681deb80 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -21,47 +21,17 @@ from tvm.runtime import Object from tvm.ir import SourceName, Span, Node as RelayNode -from . import _make -from . import _expr -from . import _base __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + @tvm._ffi.register_func("tvm.relay.std_path") def _std_path(): return __STD_PATH__ -def register_relay_node(type_key=None): - """Register a Relay node type. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -def register_relay_attr_node(type_key=None): - """Register a Relay attribute node. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay.attrs." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -@register_relay_node +@tvm._ffi.register_object("relay.Id") class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. diff --git a/python/tvm/relay/contrib.py b/python/tvm/relay/contrib.py deleted file mode 100644 index d22c67614999..000000000000 --- a/python/tvm/relay/contrib.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Contrib operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.contrib import * diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 61a5fb7c63ba..380cdf7d90ef 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -20,13 +20,13 @@ from numbers import Number as _Number import numpy as _np +import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.ir import RelayExpr, GlobalVar, BaseFunc -from .base import RelayNode, register_relay_node -from . import _make -from . import _expr +from .base import RelayNode +from . import _ffi_api from . import ty as _ty # alias relay expr as Expr. @@ -54,7 +54,7 @@ def astype(self, dtype): result : tvm.relay.Expr The result expression. """ - return _make.cast(self, dtype) + return _ffi_api.cast(self, dtype) def __neg__(self): return _op_make.negative(self) @@ -160,7 +160,7 @@ def __call__(self, *args): """ return Call(self, args) -@register_relay_node +@tvm._ffi.register_object("relay.Constant") class Constant(ExprWithOp): """A constant expression in Relay. @@ -170,10 +170,10 @@ class Constant(ExprWithOp): The data content of the constant expression. """ def __init__(self, data): - self.__init_handle_by_constructor__(_make.Constant, data) + self.__init_handle_by_constructor__(_ffi_api.Constant, data) -@register_relay_node +@tvm._ffi.register_object("relay.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -183,7 +183,7 @@ class Tuple(ExprWithOp): The fields in the tuple. """ def __init__(self, fields): - self.__init_handle_by_constructor__(_make.Tuple, fields) + self.__init_handle_by_constructor__(_ffi_api.Tuple, fields) def __getitem__(self, index): if index >= len(self): @@ -197,7 +197,7 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") -@register_relay_node +@tvm._ffi.register_object("relay.Var") class Var(ExprWithOp): """A local variable in Relay. @@ -216,7 +216,7 @@ class Var(ExprWithOp): """ def __init__(self, name_hint, type_annotation=None): self.__init_handle_by_constructor__( - _make.Var, name_hint, type_annotation) + _ffi_api.Var, name_hint, type_annotation) @property def name_hint(self): @@ -225,7 +225,7 @@ def name_hint(self): return name -@register_relay_node +@tvm._ffi.register_object("relay.Function") class Function(BaseFunc): """A function declaration expression. @@ -254,7 +254,7 @@ def __init__(self, type_params = convert([]) self.__init_handle_by_constructor__( - _make.Function, params, body, ret_type, type_params, attrs) + _ffi_api.Function, params, body, ret_type, type_params, attrs) def __call__(self, *args): """Invoke the global function. @@ -282,12 +282,12 @@ def with_attr(self, attr_key, attr_value): func : Function A new copy of the function """ - return _expr.FunctionWithAttr( + return _ffi_api.FunctionWithAttr( self, attr_key, convert(attr_value)) -@register_relay_node +@tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. @@ -313,10 +313,10 @@ def __init__(self, op, args, attrs=None, type_args=None): if not type_args: type_args = [] self.__init_handle_by_constructor__( - _make.Call, op, args, attrs, type_args) + _ffi_api.Call, op, args, attrs, type_args) -@register_relay_node +@tvm._ffi.register_object("relay.Let") class Let(ExprWithOp): """Let variable binding expression. @@ -333,10 +333,10 @@ class Let(ExprWithOp): """ def __init__(self, variable, value, body): self.__init_handle_by_constructor__( - _make.Let, variable, value, body) + _ffi_api.Let, variable, value, body) -@register_relay_node +@tvm._ffi.register_object("relay.If") class If(ExprWithOp): """A conditional expression in Relay. @@ -353,10 +353,10 @@ class If(ExprWithOp): """ def __init__(self, cond, true_branch, false_branch): self.__init_handle_by_constructor__( - _make.If, cond, true_branch, false_branch) + _ffi_api.If, cond, true_branch, false_branch) -@register_relay_node +@tvm._ffi.register_object("relay.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -370,10 +370,10 @@ class TupleGetItem(ExprWithOp): """ def __init__(self, tuple_value, index): self.__init_handle_by_constructor__( - _make.TupleGetItem, tuple_value, index) + _ffi_api.TupleGetItem, tuple_value, index) -@register_relay_node +@tvm._ffi.register_object("relay.RefCreate") class RefCreate(ExprWithOp): """Create a new reference from initial value. Parameters @@ -382,10 +382,10 @@ class RefCreate(ExprWithOp): The initial value. """ def __init__(self, value): - self.__init_handle_by_constructor__(_make.RefCreate, value) + self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) -@register_relay_node +@tvm._ffi.register_object("relay.RefRead") class RefRead(ExprWithOp): """Get the value inside the reference. Parameters @@ -394,10 +394,10 @@ class RefRead(ExprWithOp): The reference. """ def __init__(self, ref): - self.__init_handle_by_constructor__(_make.RefRead, ref) + self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) -@register_relay_node +@tvm._ffi.register_object("relay.RefWrite") class RefWrite(ExprWithOp): """ Update the value inside the reference. @@ -410,7 +410,7 @@ class RefWrite(ExprWithOp): The new value. """ def __init__(self, ref, value): - self.__init_handle_by_constructor__(_make.RefWrite, ref, value) + self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value) class TempExpr(ExprWithOp): @@ -427,7 +427,7 @@ def realize(self): ------- The corresponding normal expression. """ - return _expr.TempExprRealize(self) + return _ffi_api.TempExprRealize(self) class TupleWrapper(object): @@ -587,4 +587,4 @@ def bind(expr, binds): result : tvm.relay.Expr The expression or function after binding. """ - return _expr.Bind(expr, binds) + return _ffi_api.Bind(expr, binds) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b4891c3d5d96..d5def8a68e0d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -26,8 +26,8 @@ from .. import expr as _expr from .. import op as _op from .. import qnn as _qnn -from ..util import get_scalar_from_constant from ... import nd as _nd +from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/util.py b/python/tvm/relay/frontend/util.py similarity index 98% rename from python/tvm/relay/util.py rename to python/tvm/relay/frontend/util.py index b207182e4113..a7f89a30b996 100644 --- a/python/tvm/relay/util.py +++ b/python/tvm/relay/frontend/util.py @@ -18,7 +18,7 @@ """ Utility functions that are used across many directories. """ from __future__ import absolute_import import numpy as np -from . import expr as _expr +from .. import expr as _expr def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ diff --git a/python/tvm/relay/image.py b/python/tvm/relay/image.py deleted file mode 100644 index 4d5cc5a47448..000000000000 --- a/python/tvm/relay/image.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Image network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.image import * diff --git a/python/tvm/relay/nn.py b/python/tvm/relay/nn.py deleted file mode 100644 index 2070eb9181e6..000000000000 --- a/python/tvm/relay/nn.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Neural network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.nn import * diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 1a1d0d3ff7ed..b3054d67885b 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,7 +41,6 @@ from . import _transform from . import _reduce from . import _algorithm -from ..base import register_relay_node def _register_op_make(): diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 4cd4b2a2a465..e6bd6bf230dd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -19,13 +19,12 @@ import tvm._ffi from tvm.driver import lower, build -from ..base import register_relay_node from ..expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make -@register_relay_node +@tvm._ffi.register_object("relay.Op") class Op(RelayExpr): """A Relay operator definition.""" diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 12abf4a787db..9a5fb5592e90 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -15,315 +15,314 @@ # specific language governing permissions and limitations # under the License. """The attributes node used for Relay operators""" - from tvm.ir import Attrs -from ..base import register_relay_attr_node +import tvm._ffi -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradAttrs") class Conv2DWinogradAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs") class Conv2DWinogradWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_weight_transform""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.GlobalPool2DAttrs") class GlobalPool2DAttrs(Attrs): """Attributes for nn.global_pool""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BiasAddAttrs") class BiasAddAttrs(Attrs): """Atttribute of nn.bias_add""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DenseAttrs") class DenseAttrs(Attrs): """Attributes for nn.dense""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs") class FIFOBufferAttrs(Attrs): """Attributes for nn.fifo_buffer""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.UpSamplingAttrs") class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.UpSampling3DAttrs") class UpSampling3DAttrs(Attrs): """Attributes for nn.upsampling3d""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes for nn.pad""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MirrorPadAttrs") class MirrorPadAttrs(Attrs): """Attributes for nn.mirror_pad""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LeakyReluAttrs") class LeakyReluAttrs(Attrs): """Attributes for nn.leaky_relu""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.PReluAttrs") class PReluAttrs(Attrs): """Attributes for nn.prelu""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for nn.dropout""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes for nn.batch_norm""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LRNAttrs") class LRNAttrs(Attrs): """Attributes for nn.lrn""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.L2NormalizeAttrs") class L2NormalizeAttrs(Attrs): """Attributes for nn.l2_normalize""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DeformableConv2DAttrs") class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ResizeAttrs") class ResizeAttrs(Attrs): """Attributes for image.resize""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs") class CropAndResizeAttrs(Attrs): """Attributes for image.crop_and_resize""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for algorithm.argsort""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs") class OnDeviceAttrs(Attrs): """Attributes for annotation.on_device""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DebugAttrs") class DebugAttrs(Attrs): """Attributes for debug""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs") class DeviceCopyAttrs(Attrs): """Attributes for tensor.device_copy""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.CastAttrs") class CastAttrs(Attrs): """Attributes for transform.cast""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ConcatenateAttrs") class ConcatenateAttrs(Attrs): """Attributes for tensor.concatenate""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TransposeAttrs") class TransposeAttrs(Attrs): """Attributes for transform.transpose""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ReshapeAttrs") class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes for transform.take""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.InitOpAttrs") class InitOpAttrs(Attrs): """Attributes for ops specifying a tensor""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ArangeAttrs") class ArangeAttrs(Attrs): """Attributes used in arange operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes used in stack operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes used in repeat operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes used in tile operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ReverseAttrs") class ReverseAttrs(Attrs): """Attributes used in reverse operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes used in squeeze operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes for transform.split""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes for transform.stranded_slice""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SliceLikeAttrs") class SliceLikeAttrs(Attrs): """Attributes for transform.slice_like""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ClipAttrs") class ClipAttrs(Attrs): """Attributes for transform.clip""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): """Attributes for tensor.shape_of""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MultiBoxPriorAttrs") class MultiBoxPriorAttrs(Attrs): """Attributes for vision.multibox_prior""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MultiBoxTransformLocAttrs") class MultiBoxTransformLocAttrs(Attrs): """Attributes for vision.multibox_transform_loc""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.GetValidCountsAttrs") class GetValidCountsAttrs(Attrs): """Attributes for vision.get_valid_counts""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.NonMaximumSuppressionAttrs") class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ROIAlignAttrs") class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ROIPoolAttrs") class ROIPoolAttrs(Attrs): """Attributes for vision.roi_pool""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.YoloReorgAttrs") class YoloReorgAttrs(Attrs): """Attributes for vision.yolo_reorg""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ProposalAttrs") class ProposalAttrs(Attrs): """Attributes used in proposal operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool2DAttrs") class MaxPool2DAttrs(Attrs): """Attributes used in max_pool2d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool2DAttrs") class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool1DAttrs") class MaxPool1DAttrs(Attrs): """Attributes used in max_pool1d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool1DAttrs") class AvgPool1DAttrs(Attrs): """Attributes used in avg_pool1d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool3DAttrs") class MaxPool3DAttrs(Attrs): """Attributes used in max_pool3d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool3DAttrs") class AvgPool3DAttrs(Attrs): """Attributes used in avg_pool3d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BitPackAttrs") class BitPackAttrs(Attrs): """Attributes used in bitpack operator""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BinaryConv2DAttrs") class BinaryConv2DAttrs(Attrs): """Attributes used in bitserial conv2d operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BinaryDenseAttrs") class BinaryDenseAttrs(Attrs): """Attributes used in bitserial dense operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" -@register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SubPixelAttrs") class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 0955978f81a0..6a30eb2b7ec9 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -38,7 +38,7 @@ def cast(data, dtype): result : relay.Expr The casted result. """ - from .. import _make as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast(data, dtype) @@ -55,7 +55,7 @@ def cast_like(data, dtype_like): result : relay.Expr The casted result. """ - from .. import _make as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast_like(data, dtype_like) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index f9874b78467e..b1c19092b4c7 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from .. import op as reg -from ...util import get_scalar_from_constant +from ...frontend.util import get_scalar_from_constant ################################################# # Register the functions for different operators. diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index b77516de6839..2658a0aa7dad 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -24,7 +24,6 @@ from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg -from ..base import register_relay_node from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op @@ -58,7 +57,7 @@ def simulated_quantize_compute(attrs, inputs, out_type): _reg.register_injective_schedule("annotation.cast_hint") -@register_relay_node +@tvm._ffi.register_object("relay.QAnnotateExpr") class QAnnotateExpr(_expr.TempExpr): """A special kind of Expr for Annotating. diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index fbac767cea24..bb3db99eed79 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,7 +19,6 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..base import register_relay_node from ..op import op as _reg from . import _quantize from .quantize import _forward_op @@ -30,7 +29,7 @@ def _register(func): return _register(frewrite) if frewrite is not None else _register -@register_relay_node +@tvm._ffi.register_object("relay.QPartitionExpr") class QPartitionExpr(_expr.TempExpr): def __init__(self, expr): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 56a4645058e5..2ad4e18771d7 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,12 +17,12 @@ #pylint: disable=unused-argument, not-context-manager """Automatic quantization toolkit.""" import tvm.ir +from tvm.runtime import Object from . import _quantize from ._calibrate import calibrate from .. import expr as _expr from .. import transform as _transform -from ..base import Object, register_relay_node class QAnnotateKind(object): @@ -52,7 +52,7 @@ def _forward_op(ref_call, args): ref_call.op, args, ref_call.attrs, ref_call.type_args) -@register_relay_node("relay.quantize.QConfig") +@tvm._ffi.register_object("relay.quantize.QConfig") class QConfig(Object): """Configure the quantization behavior by setting config variables. diff --git a/python/tvm/relay/_expr.py b/python/tvm/relay/transform/__init__.py similarity index 79% rename from python/tvm/relay/_expr.py rename to python/tvm/relay/transform/__init__.py index 70c13ce4eaa8..93d4341635a0 100644 --- a/python/tvm/relay/_expr.py +++ b/python/tvm/relay/transform/__init__.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface of expr function exposed from C++.""" -import tvm._ffi +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace containing transformations.""" +# transformation passes +from .transform import * -tvm._ffi._init_api("relay._expr", __name__) +from . import memory_alloc diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/transform/_ffi_api.py similarity index 93% rename from python/tvm/relay/_transform.py rename to python/tvm/relay/transform/_ffi_api.py index a4168dfb5c0c..32c79cb6b2a3 100644 --- a/python/tvm/relay/_transform.py +++ b/python/tvm/relay/transform/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the Relay type inference and checking.""" +"""FFI APIs for Relay transformation passes.""" import tvm._ffi tvm._ffi._init_api("relay._transform", __name__) diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py similarity index 98% rename from python/tvm/relay/memory_alloc.py rename to python/tvm/relay/transform/memory_alloc.py index f8e981121031..c238730807d3 100644 --- a/python/tvm/relay/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,12 +19,13 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -from .expr_functor import ExprMutator -from .scope_builder import ScopeBuilder +from ..expr_functor import ExprMutator +from ..scope_builder import ScopeBuilder from . import transform -from . import op, ty, expr -from .. import DataType, register_func -from .backend import compile_engine +from .. import op +from ... import DataType, register_func +from .. import ty, expr +from ..backend import compile_engine def is_primitive(call): diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform/transform.py similarity index 92% rename from python/tvm/relay/transform.py rename to python/tvm/relay/transform/transform.py index b2565f3f97eb..43a116e64e5b 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -28,8 +28,7 @@ from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass from tvm import relay -from . import _transform -from .base import register_relay_node +from . import _ffi_api def build_config(opt_level=2, @@ -83,7 +82,7 @@ def build_config(opt_level=2, disabled_pass, trace) -@register_relay_node +@tvm._ffi.register_object("relay.FunctionPass") class FunctionPass(Pass): """A pass that works on each tvm.relay.Function in a module. A function pass class should be created through `function_pass`. @@ -98,7 +97,7 @@ def InferType(): ret : tvm.relay.Pass The registered type inference pass. """ - return _transform.InferType() + return _ffi_api.InferType() def FoldScaleAxis(): @@ -116,7 +115,7 @@ def FoldScaleAxis(): forward_fold_scale_axis as backward folding targets the common conv->bn pattern. """ - return _transform.FoldScaleAxis() + return _ffi_api.FoldScaleAxis() def BackwardFoldScaleAxis(): @@ -133,7 +132,7 @@ def BackwardFoldScaleAxis(): before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern. """ - return _transform.BackwardFoldScaleAxis() + return _ffi_api.BackwardFoldScaleAxis() def RemoveUnusedFunctions(entry_functions=None): """Remove unused global relay functions in a relay module. @@ -150,7 +149,7 @@ def RemoveUnusedFunctions(entry_functions=None): """ if entry_functions is None: entry_functions = ['main'] - return _transform.RemoveUnusedFunctions(entry_functions) + return _ffi_api.RemoveUnusedFunctions(entry_functions) def ForwardFoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. @@ -166,7 +165,7 @@ def ForwardFoldScaleAxis(): before using forward_fold_scale_axis, as backward folding targets the common conv->bn pattern. """ - return _transform.ForwardFoldScaleAxis() + return _ffi_api.ForwardFoldScaleAxis() def SimplifyInference(): @@ -178,7 +177,7 @@ def SimplifyInference(): ret: tvm.relay.Pass The registered pass to perform operator simplification. """ - return _transform.SimplifyInference() + return _ffi_api.SimplifyInference() def FastMath(): @@ -189,7 +188,7 @@ def FastMath(): ret: tvm.relay.Pass The registered pass to perform fast math operations. """ - return _transform.FastMath() + return _ffi_api.FastMath() def CanonicalizeOps(): @@ -202,7 +201,7 @@ def CanonicalizeOps(): ret: tvm.relay.Pass The registered pass performing the canonicalization. """ - return _transform.CanonicalizeOps() + return _ffi_api.CanonicalizeOps() def DeadCodeElimination(inline_once=False): @@ -218,7 +217,7 @@ def DeadCodeElimination(inline_once=False): ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. """ - return _transform.DeadCodeElimination(inline_once) + return _ffi_api.DeadCodeElimination(inline_once) def FoldConstant(): @@ -229,7 +228,7 @@ def FoldConstant(): ret : tvm.relay.Pass The registered pass for constant folding. """ - return _transform.FoldConstant() + return _ffi_api.FoldConstant() def FuseOps(fuse_opt_level=-1): @@ -246,7 +245,7 @@ def FuseOps(fuse_opt_level=-1): ret : tvm.relay.Pass The registered pass for operator fusion. """ - return _transform.FuseOps(fuse_opt_level) + return _ffi_api.FuseOps(fuse_opt_level) def CombineParallelConv2D(min_num_branches=3): @@ -263,7 +262,7 @@ def CombineParallelConv2D(min_num_branches=3): ret: tvm.relay.Pass The registered pass that combines parallel conv2d operators. """ - return _transform.CombineParallelConv2D(min_num_branches) + return _ffi_api.CombineParallelConv2D(min_num_branches) def CombineParallelDense(min_num_branches=3): @@ -295,7 +294,7 @@ def CombineParallelDense(min_num_branches=3): ret: tvm.relay.Pass The registered pass that combines parallel dense operators. """ - return _transform.CombineParallelDense(min_num_branches) + return _ffi_api.CombineParallelDense(min_num_branches) def AlterOpLayout(): @@ -309,7 +308,7 @@ def AlterOpLayout(): ret : tvm.relay.Pass The registered pass that alters the layout of operators. """ - return _transform.AlterOpLayout() + return _ffi_api.AlterOpLayout() def ConvertLayout(desired_layout): @@ -337,7 +336,7 @@ def ConvertLayout(desired_layout): pass: FunctionPass The pass. """ - return _transform.ConvertLayout(desired_layout) + return _ffi_api.ConvertLayout(desired_layout) def Legalize(legalize_map_attr_name="FTVMLegalize"): @@ -357,7 +356,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): ret : tvm.relay.Pass The registered pass that rewrites an expr. """ - return _transform.Legalize(legalize_map_attr_name) + return _ffi_api.Legalize(legalize_map_attr_name) def MergeComposite(pattern_table): @@ -382,7 +381,7 @@ def MergeComposite(pattern_table): pattern_names.append(pattern_name) patterns.append(pattern) - return _transform.MergeComposite(pattern_names, patterns) + return _ffi_api.MergeComposite(pattern_names, patterns) def RewriteAnnotatedOps(fallback_device): @@ -403,7 +402,7 @@ def RewriteAnnotatedOps(fallback_device): The registered pass that rewrites an expression with annotated `on_device` operators. """ - return _transform.RewriteDeviceAnnotation(fallback_device) + return _ffi_api.RewriteDeviceAnnotation(fallback_device) def ToANormalForm(): @@ -417,7 +416,7 @@ def ToANormalForm(): ret: Union[tvm.relay.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ - return _transform.ToANormalForm() + return _ffi_api.ToANormalForm() def ToCPS(expr, mod=None): @@ -431,7 +430,7 @@ def ToCPS(expr, mod=None): result: tvm.relay.Pass The registered pass that transforms an expression into CPS. """ - return _transform.to_cps(expr, mod) + return _ffi_api.to_cps(expr, mod) def EtaExpand(expand_constructor=False, expand_global_var=False): @@ -450,7 +449,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False): ret: tvm.relay.Pass The registered pass that eta expands an expression. """ - return _transform.EtaExpand(expand_constructor, expand_global_var) + return _ffi_api.EtaExpand(expand_constructor, expand_global_var) def ToGraphNormalForm(): @@ -461,7 +460,7 @@ def ToGraphNormalForm(): ret : tvm.relay.Pass The registered pass that transforms an expression into Graph Normal Form. """ - return _transform.ToGraphNormalForm() + return _ffi_api.ToGraphNormalForm() def EliminateCommonSubexpr(fskip=None): @@ -478,7 +477,7 @@ def EliminateCommonSubexpr(fskip=None): ret : tvm.relay.Pass The registered pass that eliminates common subexpressions. """ - return _transform.EliminateCommonSubexpr(fskip) + return _ffi_api.EliminateCommonSubexpr(fskip) def PartialEvaluate(): @@ -496,7 +495,7 @@ def PartialEvaluate(): ret: tvm.relay.Pass The registered pass that performs partial evaluation on an expression. """ - return _transform.PartialEvaluate() + return _ffi_api.PartialEvaluate() def CanonicalizeCast(): @@ -508,7 +507,7 @@ def CanonicalizeCast(): ret : tvm.relay.Pass The registered pass that canonicalizes cast expression. """ - return _transform.CanonicalizeCast() + return _ffi_api.CanonicalizeCast() def LambdaLift(): @@ -520,7 +519,7 @@ def LambdaLift(): ret : tvm.relay.Pass The registered pass that lifts the lambda function. """ - return _transform.LambdaLift() + return _ffi_api.LambdaLift() def PrintIR(show_meta_data=True): @@ -537,7 +536,7 @@ def PrintIR(show_meta_data=True): ret : tvm.relay.Pass The registered pass that prints the module IR. """ - return _transform.PrintIR(show_meta_data) + return _ffi_api.PrintIR(show_meta_data) def PartitionGraph(): @@ -549,7 +548,7 @@ def PartitionGraph(): ret: tvm.relay.Pass The registered pass that partitions the Relay program. """ - return _transform.PartitionGraph() + return _ffi_api.PartitionGraph() @@ -568,7 +567,7 @@ def AnnotateTarget(target): The annotated pass that wrapps ops with subgraph_start and subgraph_end. """ - return _transform.AnnotateTarget(target) + return _ffi_api.AnnotateTarget(target) def Inline(): @@ -581,7 +580,7 @@ def Inline(): ret: tvm.relay.Pass The registered pass that performs inlining for a Relay IR module. """ - return _transform.Inline() + return _ffi_api.Inline() def gradient(expr, mod=None, mode='higher_order'): @@ -609,9 +608,9 @@ def gradient(expr, mod=None, mode='higher_order'): The transformed expression. """ if mode == 'first_order': - return _transform.first_order_gradient(expr, mod) + return _ffi_api.first_order_gradient(expr, mod) if mode == 'higher_order': - return _transform.gradient(expr, mod) + return _ffi_api.gradient(expr, mod) raise Exception('unknown mode') @@ -634,7 +633,7 @@ def to_cps(func, mod=None): result: tvm.relay.Function The output function. """ - return _transform.to_cps(func, mod) + return _ffi_api.to_cps(func, mod) def un_cps(func): @@ -654,7 +653,7 @@ def un_cps(func): result: tvm.relay.Function The output function """ - return _transform.un_cps(func) + return _ffi_api.un_cps(func) def _wrap_class_function_pass(pass_cls, pass_info): @@ -670,7 +669,7 @@ def __init__(self, *args, **kwargs): def _pass_func(func, mod, ctx): return inst.transform_function(func, mod, ctx) self.__init_handle_by_constructor__( - _transform.MakeFunctionPass, _pass_func, pass_info) + _ffi_api.MakeFunctionPass, _pass_func, pass_info) self._inst = inst def __getattr__(self, name): @@ -778,7 +777,7 @@ def create_function_pass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): raise TypeError("pass_func must be a callable for Module pass") - return _transform.MakeFunctionPass(pass_arg, info) + return _ffi_api.MakeFunctionPass(pass_arg, info) if pass_func: return create_function_pass(pass_func) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 13d7f9197e79..19cc10aba41e 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -20,10 +20,10 @@ from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType -from .base import RelayNode, register_relay_node -from . import _make +from .base import RelayNode +from . import _ffi_api -Any = _make.Any +Any = _ffi_api.Any def type_has_any(tensor_type): """Check whether type has any as a shape. @@ -36,7 +36,7 @@ def type_has_any(tensor_type): has_any : bool The check result. """ - return _make.IsDynamic(tensor_type) + return _ffi_api.IsDynamic(tensor_type) def ShapeVar(name): diff --git a/python/tvm/relay/vision.py b/python/tvm/relay/vision.py deleted file mode 100644 index f428295f03f4..000000000000 --- a/python/tvm/relay/vision.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Vision network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.vision import * diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc index 726ccbbb3411..540284848d7c 100644 --- a/src/relay/analysis/alpha_equal.cc +++ b/src/relay/analysis/alpha_equal.cc @@ -581,7 +581,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); } -TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal") +TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); @@ -591,18 +591,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal") return AlphaEqual(a, b); }); -TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal") +TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); -TVM_REGISTER_GLOBAL("relay._analysis._graph_equal") +TVM_REGISTER_GLOBAL("relay.analysis._graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); -TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal") +TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index b9ee8e148894..a12d23d88a30 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -299,24 +299,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "CallGraph: \n" << GetRef(node); }); -TVM_REGISTER_GLOBAL("relay._analysis.CallGraph") +TVM_REGISTER_GLOBAL("relay.analysis.CallGraph") .set_body_typed([](IRModule module) { return CallGraph(module); }); -TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph") +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph") .set_body_typed([](CallGraph call_graph) { std::stringstream ss; ss << call_graph; return ss.str(); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetModule") +TVM_REGISTER_GLOBAL("relay.analysis.GetModule") .set_body_typed([](CallGraph call_graph) { return call_graph->module; }); -TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; std::stringstream ss; @@ -324,19 +324,19 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") return ss.str(); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar") +TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return static_cast(entry_node->GetRefCount()); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount") +TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return static_cast(entry_node->size()); }); -TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive") +TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return entry_node->IsRecursive(); diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index 3667d8a47826..8cb517f7e33d 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -74,7 +74,7 @@ Pass ExtractFusedFunctions() { "ExtractFusedFunctions"); } -TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions); +TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions); } // namespace transform diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 4f0e829cdbde..95c2f731ff72 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -104,7 +104,7 @@ Array PyDetectFeature(const Expr& expr, const IRModule& mod) { return static_cast>(fs); } -TVM_REGISTER_GLOBAL("relay._analysis.detect_feature") +TVM_REGISTER_GLOBAL("relay.analysis.detect_feature") .set_body_typed(PyDetectFeature); } // namespace relay diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index d43059cf9f6a..b4835ccb7a3c 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -186,7 +186,7 @@ Kind KindCheck(const Type& t, const IRModule& mod) { return kc.Check(t); } -TVM_REGISTER_GLOBAL("relay._analysis.check_kind") +TVM_REGISTER_GLOBAL("relay.analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], IRModule({}, {})); diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index 49fe2a3900b3..fecde3c75669 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.GetTotalMacNumber") +TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber") .set_body_typed(GetTotalMacNumber); } // namespace mac_count diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 14be6b751354..919065469a4d 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -310,7 +310,7 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { } // expose for testing only -TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") +TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") .set_body_typed( [](const Match& match, const IRModule& mod_ref) { IRModule call_mod = mod_ref; diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 8aa1ac9c5a8b..a6ac9ce9a7ec 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -659,7 +659,7 @@ bool TypeSolver::Solve() { } // Expose type solver only for debugging purposes. -TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") +TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 88c89dfee6b6..6a151d7d21f1 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -274,10 +274,10 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.free_vars") +TVM_REGISTER_GLOBAL("relay.analysis.free_vars") .set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") +TVM_REGISTER_GLOBAL("relay.analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; if (x.as()) { @@ -287,10 +287,10 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.all_vars") +TVM_REGISTER_GLOBAL("relay.analysis.all_vars") .set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; @@ -301,7 +301,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; @@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 72a6fcc8a6bb..f3a2cadb363f 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_GLOBAL("relay._analysis.well_formed") +TVM_REGISTER_GLOBAL("relay.analysis.well_formed") .set_body_typed(WellFormed); } // namespace relay diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d8a05f401bb5..ccbe4dfc858a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -86,7 +86,7 @@ bool IsDynamic(const Type& ty) { } // TODO(@jroesch): MOVE ME -TVM_REGISTER_GLOBAL("relay._make.IsDynamic") +TVM_REGISTER_GLOBAL("relay.ir.IsDynamic") .set_body_typed(IsDynamic); Array GetShape(const Array& shape) { diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 29fc9abd0ea7..ff24825f14e8 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") +TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") .set_body_typed(PatternWildcardNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_GLOBAL("relay._make.PatternVar") +TVM_REGISTER_GLOBAL("relay.ir.PatternVar") .set_body_typed(PatternVarNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); -TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") +TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") .set_body_typed(PatternConstructorNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_GLOBAL("relay._make.PatternTuple") +TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") .set_body_typed(PatternTupleNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -105,7 +105,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_GLOBAL("relay._make.Clause") +TVM_REGISTER_GLOBAL("relay.ir.Clause") .set_body_typed(ClauseNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -125,7 +125,7 @@ Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); -TVM_REGISTER_GLOBAL("relay._make.Match") +TVM_REGISTER_GLOBAL("relay.ir.Match") .set_body_typed(MatchNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index d9eb9108af50..5da5be3c43a7 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -38,7 +38,7 @@ Constant ConstantNode::make(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay._make.Constant") +TVM_REGISTER_GLOBAL("relay.ir.Constant") .set_body_typed(ConstantNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -71,7 +71,7 @@ Tuple TupleNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay._make.Tuple") +TVM_REGISTER_GLOBAL("relay.ir.Tuple") .set_body_typed(TupleNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -96,7 +96,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay._make.Var") +TVM_REGISTER_GLOBAL("relay.ir.Var") .set_body_typed(static_cast(VarNode::make)); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -123,7 +123,7 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_GLOBAL("relay._make.Call") +TVM_REGISTER_GLOBAL("relay.ir.Call") .set_body_typed(CallNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -143,7 +143,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay._make.Let") +TVM_REGISTER_GLOBAL("relay.ir.Let") .set_body_typed(LetNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -163,7 +163,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_GLOBAL("relay._make.If") +TVM_REGISTER_GLOBAL("relay.ir.If") .set_body_typed(IfNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -182,7 +182,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") .set_body_typed(TupleGetItemNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -199,7 +199,7 @@ RefCreate RefCreateNode::make(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay._make.RefCreate") +TVM_REGISTER_GLOBAL("relay.ir.RefCreate") .set_body_typed(RefCreateNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -216,7 +216,7 @@ RefRead RefReadNode::make(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay._make.RefRead") +TVM_REGISTER_GLOBAL("relay.ir.RefRead") .set_body_typed(RefReadNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -234,7 +234,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay._make.RefWrite") +TVM_REGISTER_GLOBAL("relay.ir.RefWrite") .set_body_typed(RefWriteNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -243,12 +243,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); -TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") +TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize") .set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay._make.Any") +TVM_REGISTER_GLOBAL("relay.ir.Any") .set_body_typed([]() { return Any::make(); }); } // namespace relay diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 16e5fe1457d6..4b0239b8da21 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") +TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit") .set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); @@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_GLOBAL("relay._expr.Bind") +TVM_REGISTER_GLOBAL("relay.ir.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef input = args[0]; if (input->IsInstance()) { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 63ad4ddb26d5..d371edb31fca 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,7 +62,7 @@ bool FunctionNode::UseDefaultCompiler() const { TVM_REGISTER_NODE_TYPE(FunctionNode); -TVM_REGISTER_GLOBAL("relay._make.Function") +TVM_REGISTER_GLOBAL("relay.ir.Function") .set_body_typed([](tvm::Array params, Expr body, Type ret_type, @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->attrs << ")"; }); -TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr") +TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr") .set_body_typed( [](Function func, std::string name, ObjectRef ref) { return WithAttr(std::move(func), name, ref); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index b1bc76b18164..ce15e2a3fe70 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } -TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") +TVM_REGISTER_GLOBAL("relay.analysis._expr_hash") .set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); -TVM_REGISTER_GLOBAL("relay._analysis._type_hash") +TVM_REGISTER_GLOBAL("relay.analysis._type_hash") .set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2bdcbcf00a99..17d9788e8193 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -82,7 +82,7 @@ Expr MakeCast(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay._make.cast") +TVM_REGISTER_GLOBAL("relay.ir.cast") .set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") @@ -138,7 +138,7 @@ Expr MakeCastLike(Expr data, } -TVM_REGISTER_GLOBAL("relay._make.cast_like") +TVM_REGISTER_GLOBAL("relay.ir.cast_like") .set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 03eb51d1c2c5..75afc9e7b63e 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -560,10 +560,10 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo") +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo") .set_body_typed(CollectDeviceInfo); -TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); namespace transform { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 2b4cc32bd790..8fcef2f15c49 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -73,7 +73,7 @@ bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_GLOBAL("relay._analysis.check_constant") +TVM_REGISTER_GLOBAL("relay.analysis.check_constant") .set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_analysis_feature.py similarity index 96% rename from tests/python/relay/test_feature.py rename to tests/python/relay/test_analysis_feature.py index 3ef53d3b88b1..ec5deb3c4e60 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_analysis_feature.py @@ -18,9 +18,8 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import detect_feature +from tvm.relay.analysis import detect_feature, Feature from tvm.relay.transform import gradient -from tvm.relay.feature import Feature from tvm.relay.prelude import Prelude from tvm.relay.testing import run_infer_type diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 71428a6dbefd..b0399a53a732 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -17,10 +17,8 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add from tvm.relay.testing.config import ctx_list diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 92bb37367ef3..849f01546788 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -25,7 +25,7 @@ def test_callgraph_construct(): x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) mod["g1"] = relay.Function([x, y], x + y) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert "g1" in str(call_graph) assert relay.alpha_equal(mod, call_graph.module) @@ -38,7 +38,7 @@ def test_print_element(): x1 = relay.var("x1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3)) mod["g1"] = relay.Function([x1, y1], x1 - y1) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert "#refs = 0" in str(call_graph.print_var("g0")) assert "#refs = 0" in str(call_graph.print_var("g1")) @@ -54,13 +54,13 @@ def test_global_call_count(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], g0(x1, y1)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.global_call_count(g0) == 0 assert call_graph.global_call_count(g1) == 1 @@ -77,13 +77,13 @@ def test_ref_count(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], x1 - y1) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.ref_count(g0) == 1 assert call_graph.ref_count(g1) == 1 @@ -100,13 +100,13 @@ def test_nested_ref(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], g0(x1, y1)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.ref_count(g0) == 2 assert call_graph.ref_count(g1) == 1 @@ -138,7 +138,7 @@ def test_recursive_func(): mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.is_recursive(sum_up) assert call_graph.ref_count(sum_up) == 2 diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 49e9883d8ee8..3e7d916c96fa 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -19,7 +19,6 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.expr_functor import ExprMutator diff --git a/tests/python/relay/test_memory_alloc.py b/tests/python/relay/test_pass_memory_alloc.py similarity index 98% rename from tests/python/relay/test_memory_alloc.py rename to tests/python/relay/test_pass_memory_alloc.py index 08fc39df9ad0..c3c53121d934 100644 --- a/tests/python/relay/test_memory_alloc.py +++ b/tests/python/relay/test_pass_memory_alloc.py @@ -18,7 +18,7 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay import memory_alloc +from tvm.relay.transform import memory_alloc def check_vm_alloc(func, check_fn): mod = tvm.IRModule() diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 11f7e971b283..c4fbbc1458d9 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -25,7 +25,7 @@ from tvm import runtime from tvm.relay import transform from tvm.contrib import util -from tvm.relay.annotation import compiler_begin, compiler_end +from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.expr_functor import ExprMutator # Leverage the pass manager to write a simple white list based annotator diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index f68f64874c78..2a6103ea1fbe 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -22,7 +22,7 @@ from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count -from tvm.relay.feature import Feature +from tvm.relay.analysis import Feature def run_opt_pass(expr, passes): diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index fe4959ed8ce3..e2ac924e9661 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -16,15 +16,14 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.transform import to_cps, un_cps -from tvm.relay.feature import Feature +from tvm.relay.analysis import Feature from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor -from tvm.relay import Function, transform +from tvm.relay import transform def test_id(): diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index dc47ad350fe5..94886220874d 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -16,9 +16,9 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay -from tvm.relay import op, create_executor, transform, Feature +from tvm.relay import op, create_executor, transform +from tvm.relay.analysis import Feature from tvm.relay.analysis import detect_feature diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index d90fd29a7eb5..6d72ad3af9af 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te from tvm import relay import pytest @@ -27,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): return relay.ty.TypeRelation(func, args, num_inputs, attrs) def make_solver(): - solver = relay._analysis._test_type_solver() + solver = relay.analysis._ffi_api._test_type_solver() solver.Solve = solver("Solve") solver.Unify = solver("Unify") solver.Resolve = solver("Resolve") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a8ac27a11c0f..f2b15ec26f32 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -18,7 +18,6 @@ import pytest import tvm -from tvm import te from tvm import runtime from tvm import relay from tvm.relay.scope_builder import ScopeBuilder