From 49e1a9114c48304f216c517e31b461e9b4c4d799 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 16 Mar 2020 05:42:19 +0000 Subject: [PATCH 1/4] refactor relay python --- docs/api/python/relay/base.rst | 12 +- docs/api/python/relay/expr.rst | 38 ++--- docs/api/python/relay/scope_builder.rst | 6 +- .../graph_tuner/utils/traverse_graph.py | 4 +- python/tvm/relay/__init__.py | 138 +++++++++--------- python/tvm/relay/analysis/__init__.py | 26 ++++ .../{_analysis.py => analysis/_ffi_api.py} | 4 +- python/tvm/relay/{ => analysis}/analysis.py | 108 +++++++------- python/tvm/relay/{ => analysis}/call_graph.py | 20 +-- python/tvm/relay/{ => analysis}/feature.py | 0 python/tvm/relay/annotation.py | 21 --- python/tvm/relay/backend/compile_engine.py | 2 +- python/tvm/relay/backend/interpreter.py | 6 +- python/tvm/relay/contrib.py | 20 --- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 4 +- python/tvm/relay/frontend/tflite.py | 2 +- python/tvm/relay/{ => frontend}/util.py | 2 +- python/tvm/relay/image.py | 20 --- python/tvm/relay/ir/__init__.py | 95 ++++++++++++ python/tvm/relay/{_base.py => ir/_ffi_api.py} | 5 +- python/tvm/relay/{ => ir}/_parser.py | 10 +- python/tvm/relay/{ => ir}/adt.py | 14 +- python/tvm/relay/{ => ir}/base.py | 6 +- python/tvm/relay/{ => ir}/expr.py | 33 ++--- python/tvm/relay/{ => ir}/expr_functor.py | 2 +- python/tvm/relay/{ => ir}/loops.py | 0 python/tvm/relay/{ => ir}/parser.py | 4 +- python/tvm/relay/{ => ir}/prelude.py | 4 +- python/tvm/relay/{ => ir}/scope_builder.py | 2 +- python/tvm/relay/{ => ir}/ty.py | 6 +- python/tvm/relay/{ => ir}/type_functor.py | 0 python/tvm/relay/nn.py | 20 --- python/tvm/relay/op/__init__.py | 4 +- python/tvm/relay/op/_tensor_grad.py | 2 +- python/tvm/relay/op/algorithm.py | 2 +- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/relay/op/op.py | 4 +- python/tvm/relay/op/op_attrs.py | 2 +- python/tvm/relay/op/reduce.py | 2 +- python/tvm/relay/op/tensor.py | 2 +- python/tvm/relay/op/transform.py | 4 +- python/tvm/relay/op/vision/multibox.py | 2 +- python/tvm/relay/op/vision/nms.py | 2 +- python/tvm/relay/qnn/op/legalizations.py | 2 +- python/tvm/relay/qnn/op/qnn.py | 2 +- python/tvm/relay/quantize/_annotate.py | 2 +- python/tvm/relay/quantize/_partition.py | 2 +- python/tvm/relay/quantize/quantize.py | 2 +- python/tvm/relay/testing/__init__.py | 6 +- python/tvm/relay/testing/nat.py | 6 +- python/tvm/relay/testing/py_converter.py | 6 +- .../relay/{_expr.py => transform/__init__.py} | 9 +- .../{_transform.py => transform/_ffi_api.py} | 2 +- .../tvm/relay/{ => transform}/memory_alloc.py | 11 +- python/tvm/relay/{ => transform}/transform.py | 76 +++++----- python/tvm/relay/vision.py | 20 --- src/relay/analysis/alpha_equal.cc | 8 +- src/relay/analysis/call_graph.cc | 14 +- src/relay/analysis/extract_fused_functions.cc | 2 +- src/relay/analysis/feature.cc | 2 +- src/relay/analysis/kind_check.cc | 2 +- src/relay/analysis/mac_count.cc | 2 +- src/relay/analysis/match_exhaustion.cc | 2 +- src/relay/analysis/type_solver.cc | 2 +- src/relay/analysis/util.cc | 12 +- src/relay/analysis/well_formed.cc | 2 +- src/relay/backend/compile_engine.cc | 2 +- src/relay/ir/adt.cc | 12 +- src/relay/ir/expr.cc | 24 +-- src/relay/ir/expr_functor.cc | 4 +- src/relay/ir/function.cc | 4 +- src/relay/ir/hash.cc | 4 +- src/relay/op/tensor/transform.cc | 4 +- src/relay/transforms/device_annotation.cc | 4 +- src/relay/transforms/fold_constant.cc | 2 +- tests/python/relay/test_adt.py | 2 +- tests/python/relay/test_any.py | 2 +- .../relay/test_backend_graph_runtime.py | 2 - .../python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_feature.py | 5 +- tests/python/relay/test_ir_module.py | 2 +- tests/python/relay/test_ir_well_formed.py | 2 +- tests/python/relay/test_pass_annotation.py | 2 +- tests/python/relay/test_pass_gradient.py | 2 +- tests/python/relay/test_pass_manager.py | 3 +- tests/python/relay/test_pass_partial_eval.py | 8 +- .../python/relay/test_pass_partition_graph.py | 4 +- .../test_pass_remove_unused_functions.py | 2 +- .../relay/test_pass_to_a_normal_form.py | 4 +- tests/python/relay/test_pass_to_cps.py | 6 +- .../python/relay/test_pass_unmatched_cases.py | 2 +- tests/python/relay/test_py_converter.py | 3 +- tests/python/relay/test_type_functor.py | 4 +- tests/python/relay/test_type_solver.py | 3 +- tests/python/relay/test_vm.py | 7 +- tests/python/relay/test_vm_serialization.py | 4 +- .../test_autotvm_graph_tuner_utils.py | 2 +- 98 files changed, 503 insertions(+), 488 deletions(-) create mode 100644 python/tvm/relay/analysis/__init__.py rename python/tvm/relay/{_analysis.py => analysis/_ffi_api.py} (88%) rename python/tvm/relay/{ => analysis}/analysis.py (79%) rename python/tvm/relay/{ => analysis}/call_graph.py (88%) rename python/tvm/relay/{ => analysis}/feature.py (100%) delete mode 100644 python/tvm/relay/annotation.py delete mode 100644 python/tvm/relay/contrib.py rename python/tvm/relay/{ => frontend}/util.py (98%) delete mode 100644 python/tvm/relay/image.py create mode 100644 python/tvm/relay/ir/__init__.py rename python/tvm/relay/{_base.py => ir/_ffi_api.py} (82%) rename python/tvm/relay/{ => ir}/_parser.py (99%) rename python/tvm/relay/{ => ir}/adt.py (89%) rename python/tvm/relay/{ => ir}/base.py (93%) rename python/tvm/relay/{ => ir}/expr.py (94%) rename python/tvm/relay/{ => ir}/expr_functor.py (99%) rename python/tvm/relay/{ => ir}/loops.py (100%) rename python/tvm/relay/{ => ir}/parser.py (94%) rename python/tvm/relay/{ => ir}/prelude.py (99%) rename python/tvm/relay/{ => ir}/scope_builder.py (99%) rename python/tvm/relay/{ => ir}/ty.py (95%) rename python/tvm/relay/{ => ir}/type_functor.py (100%) delete mode 100644 python/tvm/relay/nn.py rename python/tvm/relay/{_expr.py => transform/__init__.py} (79%) rename python/tvm/relay/{_transform.py => transform/_ffi_api.py} (93%) rename python/tvm/relay/{ => transform}/memory_alloc.py (98%) rename python/tvm/relay/{ => transform}/transform.py (92%) delete mode 100644 python/tvm/relay/vision.py diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst index dc9dac0f67bd..61b18ac9c815 100644 --- a/docs/api/python/relay/base.rst +++ b/docs/api/python/relay/base.rst @@ -15,16 +15,16 @@ specific language governing permissions and limitations under the License. -tvm.relay.base +tvm.relay.ir.base -------------- -.. automodule:: tvm.relay.base +.. automodule:: tvm.relay.ir.base -.. autofunction:: tvm.relay.base.register_relay_node +.. autofunction:: tvm.relay.ir.base.register_relay_node -.. autofunction:: tvm.relay.base.register_relay_attr_node +.. autofunction:: tvm.relay.ir.base.register_relay_attr_node -.. autoclass:: tvm.relay.base.RelayNode +.. autoclass:: tvm.relay.ir.base.RelayNode :members: -.. autoclass:: tvm.relay.base.Id +.. autoclass:: tvm.relay.ir.base.Id :members: diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst index 57a4a2511b72..e7944730a4b8 100644 --- a/docs/api/python/relay/expr.rst +++ b/docs/api/python/relay/expr.rst @@ -15,56 +15,56 @@ specific language governing permissions and limitations under the License. -tvm.relay.expr +tvm.relay.ir.expr -------------- -.. automodule:: tvm.relay.expr +.. automodule:: tvm.relay.ir.expr -.. autofunction:: tvm.relay.expr.var +.. autofunction:: tvm.relay.ir.expr.var -.. autofunction:: tvm.relay.expr.const +.. autofunction:: tvm.relay.ir.expr.const -.. autofunction:: tvm.relay.expr.bind +.. autofunction:: tvm.relay.ir.expr.bind -.. autoclass:: tvm.relay.expr.Expr +.. autoclass:: tvm.relay.ir.expr.Expr :members: -.. autoclass:: tvm.relay.expr.Constant +.. autoclass:: tvm.relay.ir.expr.Constant :members: -.. autoclass:: tvm.relay.expr.Tuple +.. autoclass:: tvm.relay.ir.expr.Tuple :members: -.. autoclass:: tvm.relay.expr.Function +.. autoclass:: tvm.relay.ir.expr.Function :members: -.. autoclass:: tvm.relay.expr.Call +.. autoclass:: tvm.relay.ir.expr.Call :members: -.. autoclass:: tvm.relay.expr.Let +.. autoclass:: tvm.relay.ir.expr.Let :members: -.. autoclass:: tvm.relay.expr.If +.. autoclass:: tvm.relay.ir.expr.If :members: -.. autoclass:: tvm.relay.expr.TupleGetItem +.. autoclass:: tvm.relay.ir.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.expr.RefCreate +.. autoclass:: tvm.relay.ir.expr.RefCreate :members: -.. autoclass:: tvm.relay.expr.RefRead +.. autoclass:: tvm.relay.ir.expr.RefRead :members: -.. autoclass:: tvm.relay.expr.RefWrite +.. autoclass:: tvm.relay.ir.expr.RefWrite :members: -.. autoclass:: tvm.relay.expr.TupleGetItem +.. autoclass:: tvm.relay.ir.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.expr.TempExpr +.. autoclass:: tvm.relay.ir.expr.TempExpr :members: -.. autoclass:: tvm.relay.expr.TupleWrapper +.. autoclass:: tvm.relay.ir.expr.TupleWrapper :members: diff --git a/docs/api/python/relay/scope_builder.rst b/docs/api/python/relay/scope_builder.rst index 6d8e01428e31..730751f7a581 100644 --- a/docs/api/python/relay/scope_builder.rst +++ b/docs/api/python/relay/scope_builder.rst @@ -15,10 +15,10 @@ specific language governing permissions and limitations under the License. -tvm.relay.scope_builder +tvm.relay.ir.scope_builder ----------------------- -.. automodule:: tvm.relay.scope_builder +.. automodule:: tvm.relay.ir.scope_builder -.. autoclass:: tvm.relay.scope_builder.ScopeBuilder +.. autoclass:: tvm.relay.ir.scope_builder.ScopeBuilder :members: diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index f1dd40440532..e463d20242d7 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,8 +21,8 @@ import tvm from tvm import relay, autotvm from tvm.relay import transform -from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple -from tvm.relay.ty import TupleType, TensorType +from tvm.relay.ir import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.ir import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f4a7c75864d5..8184b586016c 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,35 +19,33 @@ import os from sys import setrecursionlimit -from . import call_graph -from . import base -from . import ty -from . import expr -from . import type_functor -from . import expr_functor -from . import adt -from . import analysis +from . import ir +from .ir import adt, expr, ty, base, scope_builder +from .ir import prelude, loops, parser + from . import transform +from . import analysis +from .analysis import call_graph, feature, alpha_equal from .build_module import build, create_executor, optimize from .transform import build_config -from . import prelude -from . import parser from . import debug from . import param_dict -from . import feature from .backend import vm # Root operators from .op import Op +from .op import nn +from .op import image +from .op import vision +from .op import annotation from .op.reduce import * from .op.tensor import * from .op.transform import * from .op.algorithm import * -from . import nn -from . import annotation -from . import vision -from . import contrib -from . import image +from .op.nn import * +from .op.vision import * +from .op.contrib import * +from .op.image import * from . import frontend from . import backend from . import quantize @@ -55,75 +53,81 @@ # Dialects from . import qnn -from .scope_builder import ScopeBuilder # Load Memory pass -from . import memory_alloc +from .transform import memory_alloc # Required to traverse large programs setrecursionlimit(10000) # Span -Span = base.Span +Span = ir.Span # Type -Type = ty.Type -TupleType = ty.TupleType -TensorType = ty.TensorType -TypeKind = ty.TypeKind -TypeVar = ty.TypeVar -ShapeVar = ty.ShapeVar -TypeConstraint = ty.TypeConstraint -FuncType = ty.FuncType -TypeRelation = ty.TypeRelation -IncompleteType = ty.IncompleteType -scalar_type = ty.scalar_type -RefType = ty.RefType -GlobalTypeVar = ty.GlobalTypeVar -TypeCall = ty.TypeCall -Any = ty.Any +Type = ir.Type +TupleType = ir.TupleType +TensorType = ir.TensorType +TypeKind = ir.TypeKind +TypeVar = ir.TypeVar +ShapeVar = ir.ShapeVar +TypeConstraint = ir.TypeConstraint +FuncType = ir.FuncType +TypeRelation = ir.TypeRelation +IncompleteType = ir.IncompleteType +scalar_type = ir.scalar_type +RefType = ir.RefType +GlobalTypeVar = ir.GlobalTypeVar +TypeCall = ir.TypeCall +Any = ir.Any # Expr -Expr = expr.RelayExpr -Constant = expr.Constant -Tuple = expr.Tuple -Var = expr.Var -GlobalVar = expr.GlobalVar -Function = expr.Function -Call = expr.Call -Let = expr.Let -If = expr.If -TupleGetItem = expr.TupleGetItem -RefCreate = expr.RefCreate -RefRead = expr.RefRead -RefWrite = expr.RefWrite +Expr = ir.Expr +Constant = ir.Constant +Tuple = ir.Tuple +Var = ir.Var +GlobalVar = ir.GlobalVar +Function = ir.Function +Call = ir.Call +Let = ir.Let +If = ir.If +TupleGetItem = ir.TupleGetItem +RefCreate = ir.RefCreate +RefRead = ir.RefRead +RefWrite = ir.RefWrite # ADT -PatternWildcard = adt.PatternWildcard -PatternVar = adt.PatternVar -PatternConstructor = adt.PatternConstructor -PatternTuple = adt.PatternTuple -Constructor = adt.Constructor -TypeData = adt.TypeData -Clause = adt.Clause -Match = adt.Match +Pattern = ir.Pattern +PatternWildcard = ir.PatternWildcard +PatternVar = ir.PatternVar +PatternConstructor = ir.PatternConstructor +PatternTuple = ir.PatternTuple +Constructor = ir.Constructor +TypeData = ir.TypeData +Clause = ir.Clause +Match = ir.Match # helper functions -var = expr.var -const = expr.const -bind = expr.bind -module_pass = transform.module_pass -function_pass = transform.function_pass -alpha_equal = analysis.alpha_equal +var = ir.var +const = ir.const +bind = ir.bind # TypeFunctor -TypeFunctor = type_functor.TypeFunctor -TypeVisitor = type_functor.TypeVisitor -TypeMutator = type_functor.TypeMutator +TypeFunctor = ir.TypeFunctor +TypeVisitor = ir.TypeVisitor +TypeMutator = ir.TypeMutator # ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -ExprVisitor = expr_functor.ExprVisitor -ExprMutator = expr_functor.ExprMutator +ExprFunctor = ir.ExprFunctor +ExprVisitor = ir.ExprVisitor +ExprMutator = ir.ExprMutator + +# Prelude +Prelude = prelude.Prelude + +# Scope builder +ScopeBuilder = scope_builder.ScopeBuilder + +module_pass = transform.module_pass +function_pass = transform.function_pass # Parser fromtext = parser.fromtext diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py new file mode 100644 index 000000000000..eb8a91e084f3 --- /dev/null +++ b/python/tvm/relay/analysis/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace containing the analysis passes.""" +# Analysis passes +from .analysis import * + +# Call graph +from . import call_graph + +# Feature +from . import feature diff --git a/python/tvm/relay/_analysis.py b/python/tvm/relay/analysis/_ffi_api.py similarity index 88% rename from python/tvm/relay/_analysis.py rename to python/tvm/relay/analysis/_ffi_api.py index 050fcce2fb17..20b03c396e70 100644 --- a/python/tvm/relay/_analysis.py +++ b/python/tvm/relay/analysis/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the passes for Relay program analysis.""" +"""FFI APIs for Relay program analysis.""" import tvm._ffi -tvm._ffi._init_api("relay._analysis", __name__) +tvm._ffi._init_api("relay.analysis", __name__) diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis/analysis.py similarity index 79% rename from python/tvm/relay/analysis.py rename to python/tvm/relay/analysis/analysis.py index 198e0a3bf9eb..2e4465bac08c 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -22,9 +22,9 @@ """ from tvm.ir import RelayExpr, IRModule -from . import _analysis -from .ty import Type +from . import _ffi_api from .feature import Feature +from ..ir import Type def post_order_visit(expr, fvisit): @@ -34,13 +34,13 @@ def post_order_visit(expr, fvisit): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression. fvisit : function The visitor function to be applied. """ - return _analysis.post_order_visit(expr, fvisit) + return _ffi_api.post_order_visit(expr, fvisit) def well_formed(expr): @@ -48,7 +48,7 @@ def well_formed(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression Returns @@ -56,7 +56,7 @@ def well_formed(expr): well_form : bool Whether the input expression is well formed """ - return _analysis.well_formed(expr) + return _ffi_api.well_formed(expr) def check_kind(t, mod=None): @@ -85,9 +85,9 @@ def check_kind(t, mod=None): assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type """ if mod is not None: - return _analysis.check_kind(t, mod) + return _ffi_api.check_kind(t, mod) else: - return _analysis.check_kind(t) + return _ffi_api.check_kind(t) def check_constant(expr): @@ -95,7 +95,7 @@ def check_constant(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression Returns @@ -103,7 +103,7 @@ def check_constant(expr): result : bool Whether the expression is constant. """ - return _analysis.check_constant(expr) + return _ffi_api.check_constant(expr) def free_vars(expr): @@ -111,7 +111,7 @@ def free_vars(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression Returns @@ -125,7 +125,7 @@ def free_vars(expr): neural networks: usually this means weights of previous are ordered first. """ - return _analysis.free_vars(expr) + return _ffi_api.free_vars(expr) def bound_vars(expr): @@ -133,7 +133,7 @@ def bound_vars(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression Returns @@ -141,7 +141,7 @@ def bound_vars(expr): free : List[tvm.relay.Var] The list of bound variables in post-DFS order. """ - return _analysis.bound_vars(expr) + return _ffi_api.bound_vars(expr) def all_vars(expr): @@ -149,7 +149,7 @@ def all_vars(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression Returns @@ -157,7 +157,7 @@ def all_vars(expr): free : List[tvm.relay.Var] The list of all variables in post-DFS order. """ - return _analysis.all_vars(expr) + return _ffi_api.all_vars(expr) def free_type_vars(expr, mod=None): @@ -165,7 +165,7 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None): The list of free type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.free_type_vars(expr, use_mod) + return _ffi_api.free_type_vars(expr, use_mod) def bound_type_vars(expr, mod=None): @@ -185,7 +185,7 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None): The list of bound type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.bound_type_vars(expr, use_mod) + return _ffi_api.bound_type_vars(expr, use_mod) def all_type_vars(expr, mod=None): @@ -205,7 +205,7 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None): The list of all type variables in post-DFS order """ use_mod = mod if mod is not None else IRModule() - return _analysis.all_type_vars(expr, use_mod) + return _ffi_api.all_type_vars(expr, use_mod) def alpha_equal(lhs, rhs): @@ -225,10 +225,10 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.Expr + lhs : tvm.relay.ir.Expr One of the input Expression. - rhs : tvm.relay.Expr + rhs : tvm.relay.ir.Expr One of the input Expression. Returns @@ -236,7 +236,7 @@ def alpha_equal(lhs, rhs): result : bool True iff lhs is alpha equal to rhs. """ - return bool(_analysis._alpha_equal(lhs, rhs)) + return bool(_ffi_api._alpha_equal(lhs, rhs)) def assert_alpha_equal(lhs, rhs): @@ -244,13 +244,13 @@ def assert_alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.Expr + lhs : tvm.relay.ir.Expr One of the input Expression. - rhs : tvm.relay.Expr + rhs : tvm.relay.ir.Expr One of the input Expression. """ - _analysis._assert_alpha_equal(lhs, rhs) + _ffi_api._assert_alpha_equal(lhs, rhs) def graph_equal(lhs, rhs): @@ -261,10 +261,10 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.Expr + lhs : tvm.relay.ir.Expr One of the input Expression. - rhs : tvm.relay.Expr + rhs : tvm.relay.ir.Expr One of the input Expression. Returns @@ -272,7 +272,7 @@ def graph_equal(lhs, rhs): result : bool True iff lhs is data-flow equivalent to rhs. """ - return bool(_analysis._graph_equal(lhs, rhs)) + return bool(_ffi_api._graph_equal(lhs, rhs)) def assert_graph_equal(lhs, rhs): @@ -283,13 +283,13 @@ def assert_graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.Expr + lhs : tvm.relay.ir.Expr One of the input Expression. - rhs : tvm.relay.Expr + rhs : tvm.relay.ir.Expr One of the input Expression. """ - _analysis._assert_graph_equal(lhs, rhs) + _ffi_api._assert_graph_equal(lhs, rhs) def collect_device_info(expr): @@ -298,15 +298,15 @@ def collect_device_info(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression. Returns ------- - ret : Dict[tvm.relay.expr, int] - A dictionary mapping tvm.relay.Expr to device type. + ret : Dict[tvm.relay.ir.expr, int] + A dictionary mapping tvm.relay.ir.Expr to device type. """ - return _analysis.CollectDeviceInfo(expr) + return _ffi_api.CollectDeviceInfo(expr) def collect_device_annotation_ops(expr): @@ -314,16 +314,16 @@ def collect_device_annotation_ops(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression. Returns ------- - ret : Dict[tvm.relay.expr, int] - A dictionary mapping tvm.relay.Expr to device type where the keys are + ret : Dict[tvm.relay.ir.Expr, int] + A dictionary mapping tvm.relay.ir.Expr to device type where the keys are annotation expressions. """ - return _analysis.CollectDeviceAnnotationOps(expr) + return _ffi_api.CollectDeviceAnnotationOps(expr) def get_total_mac_number(expr): @@ -332,7 +332,7 @@ def get_total_mac_number(expr): Parameters ---------- - expr : tvm.relay.Expr + expr : tvm.relay.ir.Expr The input expression. Returns @@ -340,7 +340,7 @@ def get_total_mac_number(expr): result : int64 The number of MACs (multiply-accumulate) of a model """ - return _analysis.GetTotalMacNumber(expr) + return _ffi_api.GetTotalMacNumber(expr) def unmatched_cases(match, mod=None): @@ -360,7 +360,7 @@ def unmatched_cases(match, mod=None): missing_patterns : [tvm.relay.Pattern] Patterns that the match expression does not catch. """ - return _analysis.unmatched_cases(match, mod) + return _ffi_api.unmatched_cases(match, mod) def detect_feature(a, b=None): @@ -369,10 +369,10 @@ def detect_feature(a, b=None): Parameters ---------- - a : Union[tvm.relay.Expr, tvm.IRModule] + a : Union[tvm.relay.ir.Expr, tvm.IRModule] The input expression or module. - b : Optional[Union[tvm.relay.Expr, tvm.IRModule]] + b : Optional[Union[tvm.relay.ir.Expr, tvm.IRModule]] The input expression or module. The two arguments cannot both be expression or module. @@ -383,7 +383,7 @@ def detect_feature(a, b=None): """ if isinstance(a, IRModule): a, b = b, a - return {Feature(int(x)) for x in _analysis.detect_feature(a, b)} + return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)} def structural_hash(value): @@ -391,7 +391,7 @@ def structural_hash(value): Parameters ---------- - expr : Union[tvm.relay.Expr, tvm.relay.Type] + expr : Union[tvm.relay.ir.Expr, tvm.relay.Type] The expression to hash. Returns @@ -400,12 +400,12 @@ def structural_hash(value): The hash value """ if isinstance(value, RelayExpr): - return int(_analysis._expr_hash(value)) + return int(_ffi_api._expr_hash(value)) elif isinstance(value, Type): - return int(_analysis._type_hash(value)) + return int(_ffi_api._type_hash(value)) else: msg = ("found value of type {0} expected" + - "relay.Expr or relay.Type").format(type(value)) + "relay.ir.Expr or relay.Type").format(type(value)) raise TypeError(msg) @@ -421,10 +421,10 @@ def extract_fused_functions(mod): Returns ------- - ret : Dict[int, tvm.relay.expr.Function] + ret : Dict[int, tvm.relay.ir.expr.Function] A module containing only fused primitive functions """ - ret_mod = _analysis.ExtractFusedFunctions()(mod) + ret_mod = _ffi_api.ExtractFusedFunctions()(mod) ret = {} for hash_, func in ret_mod.functions.items(): ret[hash_] = func diff --git a/python/tvm/relay/call_graph.py b/python/tvm/relay/analysis/call_graph.py similarity index 88% rename from python/tvm/relay/call_graph.py rename to python/tvm/relay/analysis/call_graph.py index 8206f5dccd4c..0d1053612c8d 100644 --- a/python/tvm/relay/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -18,9 +18,9 @@ """Call graph used in Relay.""" from tvm.ir import IRModule -from .base import Object -from .expr import GlobalVar -from . import _analysis +from tvm.runtime import Object +from ..ir import GlobalVar +from . import _ffi_api class CallGraph(Object): @@ -39,7 +39,7 @@ def __init__(self, module): call_graph: CallGraph A constructed call graph. """ - self.__init_handle_by_constructor__(_analysis.CallGraph, module) + self.__init_handle_by_constructor__(_ffi_api.CallGraph, module) @property def module(self): @@ -54,7 +54,7 @@ def module(self): ret : tvm.ir.IRModule The contained IRModule """ - return _analysis.GetModule(self) + return _ffi_api.GetModule(self) def ref_count(self, var): """Return the number of references to the global var @@ -69,7 +69,7 @@ def ref_count(self, var): The number reference to the global var """ var = self._get_global_var(var) - return _analysis.GetRefCountGlobalVar(self, var) + return _ffi_api.GetRefCountGlobalVar(self, var) def global_call_count(self, var): """Return the number of global function calls from a given global var. @@ -84,7 +84,7 @@ def global_call_count(self, var): The number of global function calls from the given var. """ var = self._get_global_var(var) - return _analysis.GetGlobalVarCallCount(self, var) + return _ffi_api.GetGlobalVarCallCount(self, var) def is_recursive(self, var): """Return if the function corresponding to a var is a recursive @@ -100,7 +100,7 @@ def is_recursive(self, var): If the function corresponding to var is recurisve. """ var = self._get_global_var(var) - return _analysis.IsRecursive(self, var) + return _ffi_api.IsRecursive(self, var) def _get_global_var(self, var): """Return the global var using a given name or GlobalVar. @@ -137,8 +137,8 @@ def print_var(self, var): The call graph represented in string. """ var = self._get_global_var(var) - return _analysis.PrintCallGraphGlobalVar(self, var) + return _ffi_api.PrintCallGraphGlobalVar(self, var) def __str__(self): """Print the call graph in the topological order.""" - return _analysis.PrintCallGraph(self) + return _ffi_api.PrintCallGraph(self) diff --git a/python/tvm/relay/feature.py b/python/tvm/relay/analysis/feature.py similarity index 100% rename from python/tvm/relay/feature.py rename to python/tvm/relay/analysis/feature.py diff --git a/python/tvm/relay/annotation.py b/python/tvm/relay/annotation.py deleted file mode 100644 index 5a4065313dcc..000000000000 --- a/python/tvm/relay/annotation.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Annotation related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.annotation import * diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a51e4f7bad11..38efafeee66e 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -22,7 +22,7 @@ import numpy as np import tvm from tvm import te -from ..base import register_relay_node, Object +from ..ir.base import register_relay_node, Object from ... import target as _target from ... import autotvm from .. import expr as _expr diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 18f848c212b2..a7245987ab01 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -26,9 +26,9 @@ from . import _backend from .. import _make, analysis, transform from ... import nd -from ..base import Object, register_relay_node -from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const -from ..scope_builder import ScopeBuilder +from ..ir.base import Object, register_relay_node +from ..ir import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const +from ..ir.scope_builder import ScopeBuilder @register_relay_node diff --git a/python/tvm/relay/contrib.py b/python/tvm/relay/contrib.py deleted file mode 100644 index d22c67614999..000000000000 --- a/python/tvm/relay/contrib.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Contrib operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.contrib import * diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 6da91c17fd94..7b16bc9849aa 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -30,7 +30,7 @@ from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op -from ..loops import while_loop +from ..ir.loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 29d9d1bcb93b..d1abb4f1a5d3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,12 +27,12 @@ import tvm from tvm.ir import IRModule -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from .. import analysis from .. import expr as _expr from .. import op as _op -from ..expr_functor import ExprMutator +from ..ir import ExprMutator from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b4891c3d5d96..d5def8a68e0d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -26,8 +26,8 @@ from .. import expr as _expr from .. import op as _op from .. import qnn as _qnn -from ..util import get_scalar_from_constant from ... import nd as _nd +from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/util.py b/python/tvm/relay/frontend/util.py similarity index 98% rename from python/tvm/relay/util.py rename to python/tvm/relay/frontend/util.py index b207182e4113..a7f89a30b996 100644 --- a/python/tvm/relay/util.py +++ b/python/tvm/relay/frontend/util.py @@ -18,7 +18,7 @@ """ Utility functions that are used across many directories. """ from __future__ import absolute_import import numpy as np -from . import expr as _expr +from .. import expr as _expr def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ diff --git a/python/tvm/relay/image.py b/python/tvm/relay/image.py deleted file mode 100644 index 4d5cc5a47448..000000000000 --- a/python/tvm/relay/image.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Image network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.image import * diff --git a/python/tvm/relay/ir/__init__.py b/python/tvm/relay/ir/__init__.py new file mode 100644 index 000000000000..2a141aa271ed --- /dev/null +++ b/python/tvm/relay/ir/__init__.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace.""" +from . import base +from . import ty +from . import expr +from . import type_functor +from . import expr_functor +from . import adt +from . import prelude +from . import loops +from . import scope_builder + +# Span +Span = base.Span +SourceName = base.SourceName + +# Type +Type = ty.Type +TupleType = ty.TupleType +TensorType = ty.TensorType +TypeKind = ty.TypeKind +TypeVar = ty.TypeVar +ShapeVar = ty.ShapeVar +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType +TypeRelation = ty.TypeRelation +IncompleteType = ty.IncompleteType +scalar_type = ty.scalar_type +RefType = ty.RefType +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall +Any = ty.Any + +# Expr +Expr = expr.RelayExpr +Constant = expr.Constant +Tuple = expr.Tuple +Var = expr.Var +GlobalVar = expr.GlobalVar +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If +TupleGetItem = expr.TupleGetItem +RefCreate = expr.RefCreate +RefRead = expr.RefRead +RefWrite = expr.RefWrite + +# ADT +Pattern = adt.Pattern +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +PatternTuple = adt.PatternTuple +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match + +# helper functions +var = expr.var +const = expr.const +bind = expr.bind + +# TypeFunctor +TypeFunctor = type_functor.TypeFunctor +TypeVisitor = type_functor.TypeVisitor +TypeMutator = type_functor.TypeMutator + +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +ExprVisitor = expr_functor.ExprVisitor +ExprMutator = expr_functor.ExprMutator + +# Prelude +Prelude = prelude.Prelude + +# Scope builder +ScopeBuilder = scope_builder.ScopeBuilder diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/ir/_ffi_api.py similarity index 82% rename from python/tvm/relay/_base.py rename to python/tvm/relay/ir/_ffi_api.py index f86aa70353dc..8e9b46a14d35 100644 --- a/python/tvm/relay/_base.py +++ b/python/tvm/relay/ir/_ffi_api.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface of expr function exposed from C++.""" +"""FFI APIs for Relay program IR.""" import tvm._ffi -tvm._ffi._init_api("relay._base", __name__) +tvm._ffi._init_api("relay.ir", __name__) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/ir/_parser.py similarity index 99% rename from python/tvm/relay/_parser.py rename to python/tvm/relay/ir/_parser.py index 49bdbb393c2e..354014a78862 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/ir/_parser.py @@ -40,11 +40,11 @@ def __new__(cls, *args, **kwds): import tvm.ir._ffi_api from tvm.ir import IRModule -from .base import Span, SourceName +from . import Span, SourceName from . import adt from . import expr from . import ty -from . import op +from .. import op PYTHON_VERSION = sys.version_info.major try: @@ -56,9 +56,9 @@ def __new__(cls, *args, **kwds): .format(version=PYTHON_VERSION)) try: - from .grammar.py3.RelayVisitor import RelayVisitor - from .grammar.py3.RelayParser import RelayParser - from .grammar.py3.RelayLexer import RelayLexer + from ..grammar.py3.RelayVisitor import RelayVisitor + from ..grammar.py3.RelayParser import RelayParser + from ..grammar.py3.RelayLexer import RelayLexer except ImportError: raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/ir/adt.py similarity index 89% rename from python/tvm/relay/adt.py rename to python/tvm/relay/ir/adt.py index 9c5dac6362e2..8b4127286948 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/ir/adt.py @@ -19,7 +19,7 @@ from tvm.ir import Constructor, TypeData from .base import RelayNode, register_relay_node, Object -from . import _make +from . import _ffi_api from .ty import Type from .expr import ExprWithOp, RelayExpr, Call @@ -44,7 +44,7 @@ def __init__(self): wildcard: PatternWildcard a wildcard pattern. """ - self.__init_handle_by_constructor__(_make.PatternWildcard) + self.__init_handle_by_constructor__(_ffi_api.PatternWildcard) @register_relay_node @@ -63,7 +63,7 @@ def __init__(self, var): pv: PatternVar A variable pattern. """ - self.__init_handle_by_constructor__(_make.PatternVar, var) + self.__init_handle_by_constructor__(_ffi_api.PatternVar, var) @register_relay_node @@ -88,7 +88,7 @@ def __init__(self, constructor, patterns=None): """ if patterns is None: patterns = [] - self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns) + self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns) @register_relay_node @@ -111,7 +111,7 @@ def __init__(self, patterns=None): """ if patterns is None: patterns = [] - self.__init_handle_by_constructor__(_make.PatternTuple, patterns) + self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns) @register_relay_node @@ -133,7 +133,7 @@ def __init__(self, lhs, rhs): clause: Clause The Clause. """ - self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) + self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs) @register_relay_node @@ -160,4 +160,4 @@ def __init__(self, data, clauses, complete=True): match: tvm.relay.Expr The match expression. """ - self.__init_handle_by_constructor__(_make.Match, data, clauses, complete) + self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/ir/base.py similarity index 93% rename from python/tvm/relay/base.py rename to python/tvm/relay/ir/base.py index 0d6f22f446cd..ad801ae23062 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/ir/base.py @@ -21,12 +21,10 @@ from tvm.runtime import Object from tvm.ir import SourceName, Span, Node as RelayNode -from . import _make -from . import _expr -from . import _base -__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") +__STD_PATH__ = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), \ + os.pardir), "std") @tvm._ffi.register_func("tvm.relay.std_path") def _std_path(): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/ir/expr.py similarity index 94% rename from python/tvm/relay/expr.py rename to python/tvm/relay/ir/expr.py index 61a5fb7c63ba..1ee187b27e7b 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/ir/expr.py @@ -25,8 +25,7 @@ from tvm.ir import RelayExpr, GlobalVar, BaseFunc from .base import RelayNode, register_relay_node -from . import _make -from . import _expr +from . import _ffi_api from . import ty as _ty # alias relay expr as Expr. @@ -54,7 +53,7 @@ def astype(self, dtype): result : tvm.relay.Expr The result expression. """ - return _make.cast(self, dtype) + return _ffi_api.cast(self, dtype) def __neg__(self): return _op_make.negative(self) @@ -170,7 +169,7 @@ class Constant(ExprWithOp): The data content of the constant expression. """ def __init__(self, data): - self.__init_handle_by_constructor__(_make.Constant, data) + self.__init_handle_by_constructor__(_ffi_api.Constant, data) @register_relay_node @@ -183,7 +182,7 @@ class Tuple(ExprWithOp): The fields in the tuple. """ def __init__(self, fields): - self.__init_handle_by_constructor__(_make.Tuple, fields) + self.__init_handle_by_constructor__(_ffi_api.Tuple, fields) def __getitem__(self, index): if index >= len(self): @@ -216,7 +215,7 @@ class Var(ExprWithOp): """ def __init__(self, name_hint, type_annotation=None): self.__init_handle_by_constructor__( - _make.Var, name_hint, type_annotation) + _ffi_api.Var, name_hint, type_annotation) @property def name_hint(self): @@ -254,7 +253,7 @@ def __init__(self, type_params = convert([]) self.__init_handle_by_constructor__( - _make.Function, params, body, ret_type, type_params, attrs) + _ffi_api.Function, params, body, ret_type, type_params, attrs) def __call__(self, *args): """Invoke the global function. @@ -282,7 +281,7 @@ def with_attr(self, attr_key, attr_value): func : Function A new copy of the function """ - return _expr.FunctionWithAttr( + return _ffi_api.FunctionWithAttr( self, attr_key, convert(attr_value)) @@ -313,7 +312,7 @@ def __init__(self, op, args, attrs=None, type_args=None): if not type_args: type_args = [] self.__init_handle_by_constructor__( - _make.Call, op, args, attrs, type_args) + _ffi_api.Call, op, args, attrs, type_args) @register_relay_node @@ -333,7 +332,7 @@ class Let(ExprWithOp): """ def __init__(self, variable, value, body): self.__init_handle_by_constructor__( - _make.Let, variable, value, body) + _ffi_api.Let, variable, value, body) @register_relay_node @@ -353,7 +352,7 @@ class If(ExprWithOp): """ def __init__(self, cond, true_branch, false_branch): self.__init_handle_by_constructor__( - _make.If, cond, true_branch, false_branch) + _ffi_api.If, cond, true_branch, false_branch) @register_relay_node @@ -370,7 +369,7 @@ class TupleGetItem(ExprWithOp): """ def __init__(self, tuple_value, index): self.__init_handle_by_constructor__( - _make.TupleGetItem, tuple_value, index) + _ffi_api.TupleGetItem, tuple_value, index) @register_relay_node @@ -382,7 +381,7 @@ class RefCreate(ExprWithOp): The initial value. """ def __init__(self, value): - self.__init_handle_by_constructor__(_make.RefCreate, value) + self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) @register_relay_node @@ -394,7 +393,7 @@ class RefRead(ExprWithOp): The reference. """ def __init__(self, ref): - self.__init_handle_by_constructor__(_make.RefRead, ref) + self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) @register_relay_node @@ -410,7 +409,7 @@ class RefWrite(ExprWithOp): The new value. """ def __init__(self, ref, value): - self.__init_handle_by_constructor__(_make.RefWrite, ref, value) + self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value) class TempExpr(ExprWithOp): @@ -427,7 +426,7 @@ def realize(self): ------- The corresponding normal expression. """ - return _expr.TempExprRealize(self) + return _ffi_api.TempExprRealize(self) class TupleWrapper(object): @@ -587,4 +586,4 @@ def bind(expr, binds): result : tvm.relay.Expr The expression or function after binding. """ - return _expr.Bind(expr, binds) + return _ffi_api.Bind(expr, binds) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/ir/expr_functor.py similarity index 99% rename from python/tvm/relay/expr_functor.py rename to python/tvm/relay/ir/expr_functor.py index 8d6923920979..d3beb200e407 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/ir/expr_functor.py @@ -21,7 +21,7 @@ from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause -from .op import Op +from ..op import Op class ExprFunctor: """ diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/ir/loops.py similarity index 100% rename from python/tvm/relay/loops.py rename to python/tvm/relay/ir/loops.py diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/ir/parser.py similarity index 94% rename from python/tvm/relay/parser.py rename to python/tvm/relay/ir/parser.py index 6c4e3131e3c2..053d299da022 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/ir/parser.py @@ -16,14 +16,14 @@ # under the License. """A parser for Relay's text format.""" from __future__ import absolute_import -from .. import register_func +from ... import register_func @register_func("relay.fromtext") def fromtext(data, source_name=None): """Parse a Relay program.""" # pylint: disable=import-outside-toplevel - from tvm.relay import _parser + from . import _parser x = _parser.fromtext(data + "\n", source_name) if x is None: raise Exception("cannot parse: ", data) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/ir/prelude.py similarity index 99% rename from python/tvm/relay/prelude.py rename to python/tvm/relay/ir/prelude.py index 5288a2e08011..fa68d3ae177a 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/ir/prelude.py @@ -20,10 +20,10 @@ from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, Function, GlobalVar, If, const -from .op.tensor import add, subtract, equal +from ..op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard -from . import op +from .. import op class TensorArrayOps(object): diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/ir/scope_builder.py similarity index 99% rename from python/tvm/relay/scope_builder.py rename to python/tvm/relay/ir/scope_builder.py index cd8dc8dcd309..35357707e535 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/ir/scope_builder.py @@ -20,7 +20,7 @@ from . import ty as _ty from . import expr as _expr -from .._ffi import base as _base +from ..._ffi import base as _base class WithScope(object): """A wrapper for builder methods which introduce scoping. diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ir/ty.py similarity index 95% rename from python/tvm/relay/ty.py rename to python/tvm/relay/ir/ty.py index 13d7f9197e79..b9643803f2f6 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ir/ty.py @@ -21,9 +21,9 @@ from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType from .base import RelayNode, register_relay_node -from . import _make +from . import _ffi_api -Any = _make.Any +Any = _ffi_api.Any def type_has_any(tensor_type): """Check whether type has any as a shape. @@ -36,7 +36,7 @@ def type_has_any(tensor_type): has_any : bool The check result. """ - return _make.IsDynamic(tensor_type) + return _ffi_api.IsDynamic(tensor_type) def ShapeVar(name): diff --git a/python/tvm/relay/type_functor.py b/python/tvm/relay/ir/type_functor.py similarity index 100% rename from python/tvm/relay/type_functor.py rename to python/tvm/relay/ir/type_functor.py diff --git a/python/tvm/relay/nn.py b/python/tvm/relay/nn.py deleted file mode 100644 index 2070eb9181e6..000000000000 --- a/python/tvm/relay/nn.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Neural network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.nn import * diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 1a1d0d3ff7ed..acc975cbe854 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,13 +41,13 @@ from . import _transform from . import _reduce from . import _algorithm -from ..base import register_relay_node +from ..ir.base import register_relay_node def _register_op_make(): # pylint: disable=import-outside-toplevel from . import _make - from .. import expr + from ..ir import expr expr._op_make = _make _register_op_make() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 33a193799288..039b2f738883 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -21,7 +21,7 @@ from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple -from ..expr import Tuple, TupleGetItem, const +from ..ir.expr import Tuple, TupleGetItem, const from . import nn as _nn from .op import register_gradient from .reduce import sum as _sum diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 17fab80118af..414b458cbd7c 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,7 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make -from ..expr import TupleWrapper +from ..ir.expr import TupleWrapper def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 30918a4183b1..36f2fa565c9f 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -17,7 +17,7 @@ #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" from __future__ import absolute_import as _abs -from ...expr import TupleWrapper +from ...ir.expr import TupleWrapper from . import _make from .util import get_pad_tuple2d diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 4cd4b2a2a465..26a6b6ead8cd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -19,8 +19,8 @@ import tvm._ffi from tvm.driver import lower, build -from ..base import register_relay_node -from ..expr import RelayExpr +from ..ir.base import register_relay_node +from ..ir.expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 12abf4a787db..9224f570f39d 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -17,7 +17,7 @@ """The attributes node used for Relay operators""" from tvm.ir import Attrs -from ..base import register_relay_attr_node +from ..ir.base import register_relay_attr_node @register_relay_attr_node diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index d3226012e887..05f8932e396c 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -20,7 +20,7 @@ from . import _make from .tensor import sqrt from .transform import squeeze -from ..expr import Tuple, TupleWrapper +from ..ir.expr import Tuple, TupleWrapper def argmax(data, axis=None, keepdims=False, exclude=False): """Returns the indices of the maximum values along an axis. diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 77969185c0a7..90e4604e80f3 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -20,7 +20,7 @@ from tvm.runtime import TVMContext as _TVMContext from . import _make -from ..expr import Tuple +from ..ir.expr import Tuple # We create a wrapper function for each operator in the diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 0955978f81a0..3e18447c2c6d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -19,7 +19,7 @@ """Transform operators.""" from . import _make -from ..expr import TupleWrapper, const +from ..ir.expr import TupleWrapper, const def cast(data, dtype): @@ -38,7 +38,7 @@ def cast(data, dtype): result : relay.Expr The casted result. """ - from .. import _make as _relay_make + from ..ir import _ffi_api as _relay_make return _relay_make.cast(data, dtype) diff --git a/python/tvm/relay/op/vision/multibox.py b/python/tvm/relay/op/vision/multibox.py index 55fb01c5eaef..73f4893412fc 100644 --- a/python/tvm/relay/op/vision/multibox.py +++ b/python/tvm/relay/op/vision/multibox.py @@ -16,7 +16,7 @@ # under the License. """Multibox operations.""" from . import _make -from ...expr import TupleWrapper +from ...ir.expr import TupleWrapper def multibox_prior(data, sizes=(1.0,), diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index cba08bfba824..859c4999545d 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -16,7 +16,7 @@ # under the License. """Non-maximum suppression operations.""" from . import _make -from ...expr import TupleWrapper +from ...ir.expr import TupleWrapper def get_valid_counts(data, score_threshold, diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index f9874b78467e..b1c19092b4c7 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from .. import op as reg -from ...util import get_scalar_from_constant +from ...frontend.util import get_scalar_from_constant ################################################# # Register the functions for different operators. diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index c94a4194daee..b53992637250 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,7 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm.relay.expr import Tuple +from tvm.relay.ir import Tuple from tvm.relay.op.nn.util import get_pad_tuple2d from . import _make diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index b77516de6839..26d8a18f7b98 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -24,7 +24,7 @@ from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg -from ..base import register_relay_node +from ..ir.base import register_relay_node from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index fbac767cea24..90274e879992 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,7 +19,7 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..base import register_relay_node +from ..ir.base import register_relay_node from ..op import op as _reg from . import _quantize from .quantize import _forward_op diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 56a4645058e5..bd1ab6fd7dfc 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from ._calibrate import calibrate from .. import expr as _expr from .. import transform as _transform -from ..base import Object, register_relay_node +from ..ir.base import Object, register_relay_node class QAnnotateKind(object): diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 54c909179e4f..068c02c6d53b 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -23,9 +23,9 @@ from tvm import te import tvm.relay as relay import tvm.relay.op as op -from tvm.relay import transform -from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor -from tvm.relay import TensorType, TupleType +from tvm.relay import transform, create_executor +from tvm.relay.ir import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem +from tvm.relay.ir import TensorType, TupleType from . import mlp from . import resnet diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index eb71120610d3..d1110fdb19d5 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -19,10 +19,10 @@ Nats are useful for testing purposes, as they make it easy to write test cases for recursion and pattern matching.""" -from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.ir import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.expr import Var, Function, GlobalVar -from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType +from tvm.relay.ir import Var, Function, GlobalVar +from tvm.relay.ir import GlobalTypeVar, TypeVar, FuncType def define_nat_adt(prelude): """Defines a Peano (unary) natural number ADT. diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index eacfe379137f..e40436c4df22 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -21,10 +21,10 @@ import tvm from tvm import relay -from tvm.relay.adt import Pattern +from tvm.relay.ir import Pattern from tvm.relay.backend import compile_engine -from tvm.relay.expr import Expr, Function, GlobalVar, Var -from tvm.relay.expr_functor import ExprFunctor +from tvm.relay.ir import Expr, Function, GlobalVar, Var +from tvm.relay.ir import ExprFunctor OUTPUT_VAR_NAME = '_py_out' diff --git a/python/tvm/relay/_expr.py b/python/tvm/relay/transform/__init__.py similarity index 79% rename from python/tvm/relay/_expr.py rename to python/tvm/relay/transform/__init__.py index 70c13ce4eaa8..93d4341635a0 100644 --- a/python/tvm/relay/_expr.py +++ b/python/tvm/relay/transform/__init__.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface of expr function exposed from C++.""" -import tvm._ffi +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace containing transformations.""" +# transformation passes +from .transform import * -tvm._ffi._init_api("relay._expr", __name__) +from . import memory_alloc diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/transform/_ffi_api.py similarity index 93% rename from python/tvm/relay/_transform.py rename to python/tvm/relay/transform/_ffi_api.py index a4168dfb5c0c..32c79cb6b2a3 100644 --- a/python/tvm/relay/_transform.py +++ b/python/tvm/relay/transform/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the Relay type inference and checking.""" +"""FFI APIs for Relay transformation passes.""" import tvm._ffi tvm._ffi._init_api("relay._transform", __name__) diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py similarity index 98% rename from python/tvm/relay/memory_alloc.py rename to python/tvm/relay/transform/memory_alloc.py index f8e981121031..7b5a0b75ac6c 100644 --- a/python/tvm/relay/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,12 +19,13 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -from .expr_functor import ExprMutator -from .scope_builder import ScopeBuilder +from ..ir.expr_functor import ExprMutator +from ..ir.scope_builder import ScopeBuilder from . import transform -from . import op, ty, expr -from .. import DataType, register_func -from .backend import compile_engine +from .. import op +from ... import DataType, register_func +from ..ir import ty, expr +from ..backend import compile_engine def is_primitive(call): diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform/transform.py similarity index 92% rename from python/tvm/relay/transform.py rename to python/tvm/relay/transform/transform.py index b2565f3f97eb..e32de3e3d112 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -28,8 +28,8 @@ from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass from tvm import relay -from . import _transform -from .base import register_relay_node +from . import _ffi_api +from ..ir.base import register_relay_node def build_config(opt_level=2, @@ -98,7 +98,7 @@ def InferType(): ret : tvm.relay.Pass The registered type inference pass. """ - return _transform.InferType() + return _ffi_api.InferType() def FoldScaleAxis(): @@ -116,7 +116,7 @@ def FoldScaleAxis(): forward_fold_scale_axis as backward folding targets the common conv->bn pattern. """ - return _transform.FoldScaleAxis() + return _ffi_api.FoldScaleAxis() def BackwardFoldScaleAxis(): @@ -133,7 +133,7 @@ def BackwardFoldScaleAxis(): before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern. """ - return _transform.BackwardFoldScaleAxis() + return _ffi_api.BackwardFoldScaleAxis() def RemoveUnusedFunctions(entry_functions=None): """Remove unused global relay functions in a relay module. @@ -150,7 +150,7 @@ def RemoveUnusedFunctions(entry_functions=None): """ if entry_functions is None: entry_functions = ['main'] - return _transform.RemoveUnusedFunctions(entry_functions) + return _ffi_api.RemoveUnusedFunctions(entry_functions) def ForwardFoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. @@ -166,7 +166,7 @@ def ForwardFoldScaleAxis(): before using forward_fold_scale_axis, as backward folding targets the common conv->bn pattern. """ - return _transform.ForwardFoldScaleAxis() + return _ffi_api.ForwardFoldScaleAxis() def SimplifyInference(): @@ -178,7 +178,7 @@ def SimplifyInference(): ret: tvm.relay.Pass The registered pass to perform operator simplification. """ - return _transform.SimplifyInference() + return _ffi_api.SimplifyInference() def FastMath(): @@ -189,7 +189,7 @@ def FastMath(): ret: tvm.relay.Pass The registered pass to perform fast math operations. """ - return _transform.FastMath() + return _ffi_api.FastMath() def CanonicalizeOps(): @@ -202,7 +202,7 @@ def CanonicalizeOps(): ret: tvm.relay.Pass The registered pass performing the canonicalization. """ - return _transform.CanonicalizeOps() + return _ffi_api.CanonicalizeOps() def DeadCodeElimination(inline_once=False): @@ -218,7 +218,7 @@ def DeadCodeElimination(inline_once=False): ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. """ - return _transform.DeadCodeElimination(inline_once) + return _ffi_api.DeadCodeElimination(inline_once) def FoldConstant(): @@ -229,7 +229,7 @@ def FoldConstant(): ret : tvm.relay.Pass The registered pass for constant folding. """ - return _transform.FoldConstant() + return _ffi_api.FoldConstant() def FuseOps(fuse_opt_level=-1): @@ -246,7 +246,7 @@ def FuseOps(fuse_opt_level=-1): ret : tvm.relay.Pass The registered pass for operator fusion. """ - return _transform.FuseOps(fuse_opt_level) + return _ffi_api.FuseOps(fuse_opt_level) def CombineParallelConv2D(min_num_branches=3): @@ -263,7 +263,7 @@ def CombineParallelConv2D(min_num_branches=3): ret: tvm.relay.Pass The registered pass that combines parallel conv2d operators. """ - return _transform.CombineParallelConv2D(min_num_branches) + return _ffi_api.CombineParallelConv2D(min_num_branches) def CombineParallelDense(min_num_branches=3): @@ -295,7 +295,7 @@ def CombineParallelDense(min_num_branches=3): ret: tvm.relay.Pass The registered pass that combines parallel dense operators. """ - return _transform.CombineParallelDense(min_num_branches) + return _ffi_api.CombineParallelDense(min_num_branches) def AlterOpLayout(): @@ -309,7 +309,7 @@ def AlterOpLayout(): ret : tvm.relay.Pass The registered pass that alters the layout of operators. """ - return _transform.AlterOpLayout() + return _ffi_api.AlterOpLayout() def ConvertLayout(desired_layout): @@ -337,7 +337,7 @@ def ConvertLayout(desired_layout): pass: FunctionPass The pass. """ - return _transform.ConvertLayout(desired_layout) + return _ffi_api.ConvertLayout(desired_layout) def Legalize(legalize_map_attr_name="FTVMLegalize"): @@ -357,7 +357,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): ret : tvm.relay.Pass The registered pass that rewrites an expr. """ - return _transform.Legalize(legalize_map_attr_name) + return _ffi_api.Legalize(legalize_map_attr_name) def MergeComposite(pattern_table): @@ -382,7 +382,7 @@ def MergeComposite(pattern_table): pattern_names.append(pattern_name) patterns.append(pattern) - return _transform.MergeComposite(pattern_names, patterns) + return _ffi_api.MergeComposite(pattern_names, patterns) def RewriteAnnotatedOps(fallback_device): @@ -403,7 +403,7 @@ def RewriteAnnotatedOps(fallback_device): The registered pass that rewrites an expression with annotated `on_device` operators. """ - return _transform.RewriteDeviceAnnotation(fallback_device) + return _ffi_api.RewriteDeviceAnnotation(fallback_device) def ToANormalForm(): @@ -417,7 +417,7 @@ def ToANormalForm(): ret: Union[tvm.relay.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ - return _transform.ToANormalForm() + return _ffi_api.ToANormalForm() def ToCPS(expr, mod=None): @@ -431,7 +431,7 @@ def ToCPS(expr, mod=None): result: tvm.relay.Pass The registered pass that transforms an expression into CPS. """ - return _transform.to_cps(expr, mod) + return _ffi_api.to_cps(expr, mod) def EtaExpand(expand_constructor=False, expand_global_var=False): @@ -450,7 +450,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False): ret: tvm.relay.Pass The registered pass that eta expands an expression. """ - return _transform.EtaExpand(expand_constructor, expand_global_var) + return _ffi_api.EtaExpand(expand_constructor, expand_global_var) def ToGraphNormalForm(): @@ -461,7 +461,7 @@ def ToGraphNormalForm(): ret : tvm.relay.Pass The registered pass that transforms an expression into Graph Normal Form. """ - return _transform.ToGraphNormalForm() + return _ffi_api.ToGraphNormalForm() def EliminateCommonSubexpr(fskip=None): @@ -478,7 +478,7 @@ def EliminateCommonSubexpr(fskip=None): ret : tvm.relay.Pass The registered pass that eliminates common subexpressions. """ - return _transform.EliminateCommonSubexpr(fskip) + return _ffi_api.EliminateCommonSubexpr(fskip) def PartialEvaluate(): @@ -496,7 +496,7 @@ def PartialEvaluate(): ret: tvm.relay.Pass The registered pass that performs partial evaluation on an expression. """ - return _transform.PartialEvaluate() + return _ffi_api.PartialEvaluate() def CanonicalizeCast(): @@ -508,7 +508,7 @@ def CanonicalizeCast(): ret : tvm.relay.Pass The registered pass that canonicalizes cast expression. """ - return _transform.CanonicalizeCast() + return _ffi_api.CanonicalizeCast() def LambdaLift(): @@ -520,7 +520,7 @@ def LambdaLift(): ret : tvm.relay.Pass The registered pass that lifts the lambda function. """ - return _transform.LambdaLift() + return _ffi_api.LambdaLift() def PrintIR(show_meta_data=True): @@ -537,7 +537,7 @@ def PrintIR(show_meta_data=True): ret : tvm.relay.Pass The registered pass that prints the module IR. """ - return _transform.PrintIR(show_meta_data) + return _ffi_api.PrintIR(show_meta_data) def PartitionGraph(): @@ -549,7 +549,7 @@ def PartitionGraph(): ret: tvm.relay.Pass The registered pass that partitions the Relay program. """ - return _transform.PartitionGraph() + return _ffi_api.PartitionGraph() @@ -568,7 +568,7 @@ def AnnotateTarget(target): The annotated pass that wrapps ops with subgraph_start and subgraph_end. """ - return _transform.AnnotateTarget(target) + return _ffi_api.AnnotateTarget(target) def Inline(): @@ -581,7 +581,7 @@ def Inline(): ret: tvm.relay.Pass The registered pass that performs inlining for a Relay IR module. """ - return _transform.Inline() + return _ffi_api.Inline() def gradient(expr, mod=None, mode='higher_order'): @@ -609,9 +609,9 @@ def gradient(expr, mod=None, mode='higher_order'): The transformed expression. """ if mode == 'first_order': - return _transform.first_order_gradient(expr, mod) + return _ffi_api.first_order_gradient(expr, mod) if mode == 'higher_order': - return _transform.gradient(expr, mod) + return _ffi_api.gradient(expr, mod) raise Exception('unknown mode') @@ -634,7 +634,7 @@ def to_cps(func, mod=None): result: tvm.relay.Function The output function. """ - return _transform.to_cps(func, mod) + return _ffi_api.to_cps(func, mod) def un_cps(func): @@ -654,7 +654,7 @@ def un_cps(func): result: tvm.relay.Function The output function """ - return _transform.un_cps(func) + return _ffi_api.un_cps(func) def _wrap_class_function_pass(pass_cls, pass_info): @@ -670,7 +670,7 @@ def __init__(self, *args, **kwargs): def _pass_func(func, mod, ctx): return inst.transform_function(func, mod, ctx) self.__init_handle_by_constructor__( - _transform.MakeFunctionPass, _pass_func, pass_info) + _ffi_api.MakeFunctionPass, _pass_func, pass_info) self._inst = inst def __getattr__(self, name): @@ -778,7 +778,7 @@ def create_function_pass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): raise TypeError("pass_func must be a callable for Module pass") - return _transform.MakeFunctionPass(pass_arg, info) + return _ffi_api.MakeFunctionPass(pass_arg, info) if pass_func: return create_function_pass(pass_func) diff --git a/python/tvm/relay/vision.py b/python/tvm/relay/vision.py deleted file mode 100644 index f428295f03f4..000000000000 --- a/python/tvm/relay/vision.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, unused-import, unused-wildcard-import -"""Vision network related operators.""" -# Re-export in a specific file name so that autodoc can pick it up -from .op.vision import * diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc index 726ccbbb3411..540284848d7c 100644 --- a/src/relay/analysis/alpha_equal.cc +++ b/src/relay/analysis/alpha_equal.cc @@ -581,7 +581,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); } -TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal") +TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); @@ -591,18 +591,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal") return AlphaEqual(a, b); }); -TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal") +TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); -TVM_REGISTER_GLOBAL("relay._analysis._graph_equal") +TVM_REGISTER_GLOBAL("relay.analysis._graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); -TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal") +TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index b9ee8e148894..a12d23d88a30 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -299,24 +299,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "CallGraph: \n" << GetRef(node); }); -TVM_REGISTER_GLOBAL("relay._analysis.CallGraph") +TVM_REGISTER_GLOBAL("relay.analysis.CallGraph") .set_body_typed([](IRModule module) { return CallGraph(module); }); -TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph") +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph") .set_body_typed([](CallGraph call_graph) { std::stringstream ss; ss << call_graph; return ss.str(); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetModule") +TVM_REGISTER_GLOBAL("relay.analysis.GetModule") .set_body_typed([](CallGraph call_graph) { return call_graph->module; }); -TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; std::stringstream ss; @@ -324,19 +324,19 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") return ss.str(); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar") +TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return static_cast(entry_node->GetRefCount()); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount") +TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return static_cast(entry_node->size()); }); -TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive") +TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; return entry_node->IsRecursive(); diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index 3667d8a47826..8cb517f7e33d 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -74,7 +74,7 @@ Pass ExtractFusedFunctions() { "ExtractFusedFunctions"); } -TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions); +TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions); } // namespace transform diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 4f0e829cdbde..95c2f731ff72 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -104,7 +104,7 @@ Array PyDetectFeature(const Expr& expr, const IRModule& mod) { return static_cast>(fs); } -TVM_REGISTER_GLOBAL("relay._analysis.detect_feature") +TVM_REGISTER_GLOBAL("relay.analysis.detect_feature") .set_body_typed(PyDetectFeature); } // namespace relay diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index d43059cf9f6a..b4835ccb7a3c 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -186,7 +186,7 @@ Kind KindCheck(const Type& t, const IRModule& mod) { return kc.Check(t); } -TVM_REGISTER_GLOBAL("relay._analysis.check_kind") +TVM_REGISTER_GLOBAL("relay.analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], IRModule({}, {})); diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index 49fe2a3900b3..fecde3c75669 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.GetTotalMacNumber") +TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber") .set_body_typed(GetTotalMacNumber); } // namespace mac_count diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 14be6b751354..919065469a4d 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -310,7 +310,7 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { } // expose for testing only -TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") +TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") .set_body_typed( [](const Match& match, const IRModule& mod_ref) { IRModule call_mod = mod_ref; diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 8aa1ac9c5a8b..a6ac9ce9a7ec 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -659,7 +659,7 @@ bool TypeSolver::Solve() { } // Expose type solver only for debugging purposes. -TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") +TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 88c89dfee6b6..6a151d7d21f1 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -274,10 +274,10 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.free_vars") +TVM_REGISTER_GLOBAL("relay.analysis.free_vars") .set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") +TVM_REGISTER_GLOBAL("relay.analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; if (x.as()) { @@ -287,10 +287,10 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.all_vars") +TVM_REGISTER_GLOBAL("relay.analysis.all_vars") .set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; @@ -301,7 +301,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; @@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") } }); -TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") +TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; IRModule mod = args[1]; diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 72a6fcc8a6bb..f3a2cadb363f 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_GLOBAL("relay._analysis.well_formed") +TVM_REGISTER_GLOBAL("relay.analysis.well_formed") .set_body_typed(WellFormed); } // namespace relay diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d8a05f401bb5..ccbe4dfc858a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -86,7 +86,7 @@ bool IsDynamic(const Type& ty) { } // TODO(@jroesch): MOVE ME -TVM_REGISTER_GLOBAL("relay._make.IsDynamic") +TVM_REGISTER_GLOBAL("relay.ir.IsDynamic") .set_body_typed(IsDynamic); Array GetShape(const Array& shape) { diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 29fc9abd0ea7..ff24825f14e8 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") +TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") .set_body_typed(PatternWildcardNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_GLOBAL("relay._make.PatternVar") +TVM_REGISTER_GLOBAL("relay.ir.PatternVar") .set_body_typed(PatternVarNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); -TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") +TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") .set_body_typed(PatternConstructorNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_GLOBAL("relay._make.PatternTuple") +TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") .set_body_typed(PatternTupleNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -105,7 +105,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_GLOBAL("relay._make.Clause") +TVM_REGISTER_GLOBAL("relay.ir.Clause") .set_body_typed(ClauseNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -125,7 +125,7 @@ Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); -TVM_REGISTER_GLOBAL("relay._make.Match") +TVM_REGISTER_GLOBAL("relay.ir.Match") .set_body_typed(MatchNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index d9eb9108af50..5da5be3c43a7 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -38,7 +38,7 @@ Constant ConstantNode::make(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay._make.Constant") +TVM_REGISTER_GLOBAL("relay.ir.Constant") .set_body_typed(ConstantNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -71,7 +71,7 @@ Tuple TupleNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay._make.Tuple") +TVM_REGISTER_GLOBAL("relay.ir.Tuple") .set_body_typed(TupleNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -96,7 +96,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay._make.Var") +TVM_REGISTER_GLOBAL("relay.ir.Var") .set_body_typed(static_cast(VarNode::make)); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -123,7 +123,7 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_GLOBAL("relay._make.Call") +TVM_REGISTER_GLOBAL("relay.ir.Call") .set_body_typed(CallNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -143,7 +143,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay._make.Let") +TVM_REGISTER_GLOBAL("relay.ir.Let") .set_body_typed(LetNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -163,7 +163,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_GLOBAL("relay._make.If") +TVM_REGISTER_GLOBAL("relay.ir.If") .set_body_typed(IfNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -182,7 +182,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") .set_body_typed(TupleGetItemNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -199,7 +199,7 @@ RefCreate RefCreateNode::make(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay._make.RefCreate") +TVM_REGISTER_GLOBAL("relay.ir.RefCreate") .set_body_typed(RefCreateNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -216,7 +216,7 @@ RefRead RefReadNode::make(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay._make.RefRead") +TVM_REGISTER_GLOBAL("relay.ir.RefRead") .set_body_typed(RefReadNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -234,7 +234,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay._make.RefWrite") +TVM_REGISTER_GLOBAL("relay.ir.RefWrite") .set_body_typed(RefWriteNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -243,12 +243,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); -TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") +TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize") .set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay._make.Any") +TVM_REGISTER_GLOBAL("relay.ir.Any") .set_body_typed([]() { return Any::make(); }); } // namespace relay diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 16e5fe1457d6..4b0239b8da21 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") +TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit") .set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); @@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_GLOBAL("relay._expr.Bind") +TVM_REGISTER_GLOBAL("relay.ir.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef input = args[0]; if (input->IsInstance()) { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 63ad4ddb26d5..d371edb31fca 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,7 +62,7 @@ bool FunctionNode::UseDefaultCompiler() const { TVM_REGISTER_NODE_TYPE(FunctionNode); -TVM_REGISTER_GLOBAL("relay._make.Function") +TVM_REGISTER_GLOBAL("relay.ir.Function") .set_body_typed([](tvm::Array params, Expr body, Type ret_type, @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->attrs << ")"; }); -TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr") +TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr") .set_body_typed( [](Function func, std::string name, ObjectRef ref) { return WithAttr(std::move(func), name, ref); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index b1bc76b18164..ce15e2a3fe70 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } -TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") +TVM_REGISTER_GLOBAL("relay.analysis._expr_hash") .set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); -TVM_REGISTER_GLOBAL("relay._analysis._type_hash") +TVM_REGISTER_GLOBAL("relay.analysis._type_hash") .set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2bdcbcf00a99..17d9788e8193 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -82,7 +82,7 @@ Expr MakeCast(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay._make.cast") +TVM_REGISTER_GLOBAL("relay.ir.cast") .set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") @@ -138,7 +138,7 @@ Expr MakeCastLike(Expr data, } -TVM_REGISTER_GLOBAL("relay._make.cast_like") +TVM_REGISTER_GLOBAL("relay.ir.cast_like") .set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 03eb51d1c2c5..75afc9e7b63e 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -560,10 +560,10 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo") +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo") .set_body_typed(CollectDeviceInfo); -TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); namespace transform { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 2b4cc32bd790..8fcef2f15c49 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -73,7 +73,7 @@ bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_GLOBAL("relay._analysis.check_constant") +TVM_REGISTER_GLOBAL("relay.analysis.check_constant") .set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index deeb7330f9da..491f18df3de0 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index aa81e3113b7f..914ac12ed6ff 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.loops import while_loop +from tvm.relay.ir.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type def int32(val): diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 71428a6dbefd..b0399a53a732 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -17,10 +17,8 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add from tvm.relay.testing.config import ctx_list diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 360b6bd20416..0534c16dedaa 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.runtime import container from tvm.relay.backend.interpreter import RefValue, ConstructorValue -from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.ir import ScopeBuilder from tvm.relay import testing, create_executor diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 3ef53d3b88b1..f54fa713e957 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -18,10 +18,9 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import detect_feature +from tvm.relay.analysis import detect_feature, Feature from tvm.relay.transform import gradient -from tvm.relay.feature import Feature -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.testing import run_infer_type def test_prelude(): diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index bab82472263a..bfc9accf5906 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.testing import add_nat_definitions def constructor_list(p): diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index db953d5762e3..daf436c584cc 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -18,7 +18,7 @@ from tvm import te from tvm import relay from tvm.relay.analysis import well_formed -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude def test_let(): x = relay.Var("x") diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 49e9883d8ee8..e98c7ec1fe6f 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -22,7 +22,7 @@ from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.expr_functor import ExprMutator +from tvm.relay.ir import ExprMutator from tvm.relay import transform diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 6f2a12589fb5..48923b5ce061 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal from tvm.relay import create_executor, transform from tvm.relay.transform import gradient -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand import tvm.relay.op as op diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index aed026996a21..2b865c57123d 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -21,8 +21,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay import ExprFunctor -from tvm.relay import Function, Call +from tvm.relay.ir import ExprFunctor, Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list, run_infer_type diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f54dd6bf69c5..f2f6f85955e4 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -20,11 +20,11 @@ from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, assert_alpha_equal -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay import op, create_executor, transform -from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate -from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match -from tvm.relay import GlobalVar, Call +from tvm.relay.ir import Var, TypeVar, TupleGetItem, Let, const, RefRead, RefWrite, RefCreate +from tvm.relay.ir import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match +from tvm.relay.ir import GlobalVar, Call, Function from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 11f7e971b283..36e371d68747 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -25,8 +25,8 @@ from tvm import runtime from tvm.relay import transform from tvm.contrib import util -from tvm.relay.annotation import compiler_begin, compiler_end -from tvm.relay.expr_functor import ExprMutator +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.ir import ExprMutator # Leverage the pass manager to write a simple white list based annotator @transform.function_pass(opt_level=0) diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 33816344f562..e42bdbd94986 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -19,7 +19,7 @@ from tvm import te from tvm import relay from tvm.relay import transform -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude def test_remove_all_prelude_functions(): diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index f68f64874c78..1f0c9f33cc73 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -20,9 +20,9 @@ from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay import op, create_executor, transform -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.testing import add_nat_definitions, count -from tvm.relay.feature import Feature +from tvm.relay.analysis import Feature def run_opt_pass(expr, passes): diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index fe4959ed8ce3..76b906cf96ab 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -20,11 +20,11 @@ from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.transform import to_cps, un_cps -from tvm.relay.feature import Feature -from tvm.relay.prelude import Prelude +from tvm.relay.analysis import Feature +from tvm.relay.ir import Prelude, Function from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor -from tvm.relay import Function, transform +from tvm.relay import transform def test_id(): diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 42344bccabaa..bac0970b098b 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.relay.analysis import unmatched_cases import pytest diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index f6b1b2432d92..36ed034ba9eb 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -16,10 +16,9 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay.testing import to_python, run_as_python -from tvm.relay.prelude import Prelude +from tvm.relay.ir import Prelude from tvm.runtime.container import ADT from tvm.relay.backend.interpreter import RefValue, ConstructorValue diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index 9e023bc6b1e4..8f581b5b715c 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -19,9 +19,9 @@ from tvm import relay from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor from tvm.relay.analysis import assert_graph_equal -from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType, +from tvm.relay.ir import (TypeVar, IncompleteType, TensorType, FuncType, TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) -from tvm.relay.adt import TypeData +from tvm.relay.ir import TypeData def check_visit(typ): try: diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index d90fd29a7eb5..6d72ad3af9af 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te from tvm import relay import pytest @@ -27,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): return relay.ty.TypeRelation(func, args, num_inputs, attrs) def make_solver(): - solver = relay._analysis._test_type_solver() + solver = relay.analysis._ffi_api._test_type_solver() solver.Solve = solver("Solve") solver.Unify = solver("Unify") solver.Resolve = solver("Resolve") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a8ac27a11c0f..b31f04c5edd9 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -18,13 +18,12 @@ import pytest import tvm -from tvm import te from tvm import runtime from tvm import relay -from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.ir import ScopeBuilder from tvm.relay.testing.config import ctx_list -from tvm.relay.prelude import Prelude -from tvm.relay.loops import while_loop +from tvm.relay.ir import Prelude +from tvm.relay.ir.loops import while_loop from tvm.relay import testing def check_result(args, expected_result, mod=None): diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 5d20651a8126..a2a786e6a8be 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -24,8 +24,8 @@ from tvm.relay import vm as rly_vm from tvm import relay -from tvm.relay.scope_builder import ScopeBuilder -from tvm.relay.prelude import Prelude +from tvm.relay.ir import ScopeBuilder +from tvm.relay.ir import Prelude from tvm.contrib import util from tvm.relay import testing diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index bd0ebe0cd3f5..b558010823e1 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -28,7 +28,7 @@ from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ get_out_nodes, expr2graph, bind_inputs from tvm.autotvm.graph_tuner._base import OPT_OUT_OP -from tvm.relay.expr import Call, TupleGetItem, Tuple, Var +from tvm.relay.ir import Call, TupleGetItem, Tuple, Var def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result): From bc6530500c319e80904d65d4b54219b40ec75a85 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 16 Mar 2020 18:18:02 +0000 Subject: [PATCH 2/4] revert relay/ir/*.py to relay --- docs/api/python/relay/base.rst | 12 +- docs/api/python/relay/expr.rst | 38 ++--- docs/api/python/relay/scope_builder.rst | 6 +- .../graph_tuner/utils/traverse_graph.py | 4 +- python/tvm/relay/__init__.py | 108 +++++++------- python/tvm/relay/{ir => }/_ffi_api.py | 0 python/tvm/relay/{ir => }/_parser.py | 10 +- python/tvm/relay/{ir => }/adt.py | 16 +- python/tvm/relay/analysis/analysis.py | 56 +++---- python/tvm/relay/analysis/call_graph.py | 2 +- python/tvm/relay/backend/compile_engine.py | 10 +- python/tvm/relay/backend/interpreter.py | 12 +- python/tvm/relay/{ir => }/base.py | 34 +---- python/tvm/relay/{ir => }/expr.py | 25 ++-- python/tvm/relay/{ir => }/expr_functor.py | 2 +- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 4 +- python/tvm/relay/ir/__init__.py | 95 ------------ python/tvm/relay/{ir => }/loops.py | 0 python/tvm/relay/op/__init__.py | 3 +- python/tvm/relay/op/_tensor_grad.py | 2 +- python/tvm/relay/op/algorithm.py | 2 +- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/relay/op/op.py | 5 +- python/tvm/relay/op/op_attrs.py | 140 ++++++++++-------- python/tvm/relay/op/reduce.py | 2 +- python/tvm/relay/op/tensor.py | 2 +- python/tvm/relay/op/transform.py | 6 +- python/tvm/relay/op/vision/multibox.py | 2 +- python/tvm/relay/op/vision/nms.py | 2 +- python/tvm/relay/{ir => }/parser.py | 4 +- python/tvm/relay/{ir => }/prelude.py | 4 +- python/tvm/relay/qnn/op/qnn.py | 2 +- python/tvm/relay/quantize/_annotate.py | 3 +- python/tvm/relay/quantize/_partition.py | 3 +- python/tvm/relay/quantize/quantize.py | 4 +- python/tvm/relay/{ir => }/scope_builder.py | 2 +- python/tvm/relay/testing/__init__.py | 6 +- python/tvm/relay/testing/nat.py | 6 +- python/tvm/relay/testing/py_converter.py | 6 +- python/tvm/relay/transform/memory_alloc.py | 6 +- python/tvm/relay/transform/transform.py | 3 +- python/tvm/relay/{ir => }/ty.py | 2 +- python/tvm/relay/{ir => }/type_functor.py | 0 tests/python/relay/test_adt.py | 2 +- ...st_feature.py => test_analysis_feature.py} | 2 +- tests/python/relay/test_any.py | 2 +- .../python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_ir_module.py | 2 +- tests/python/relay/test_ir_well_formed.py | 2 +- tests/python/relay/test_pass_annotation.py | 3 +- tests/python/relay/test_pass_gradient.py | 2 +- tests/python/relay/test_pass_manager.py | 3 +- tests/python/relay/test_pass_partial_eval.py | 8 +- .../python/relay/test_pass_partition_graph.py | 2 +- .../test_pass_remove_unused_functions.py | 2 +- .../relay/test_pass_to_a_normal_form.py | 2 +- tests/python/relay/test_pass_to_cps.py | 3 +- .../python/relay/test_pass_unmatched_cases.py | 2 +- tests/python/relay/test_py_converter.py | 3 +- tests/python/relay/test_type_functor.py | 4 +- tests/python/relay/test_vm.py | 6 +- tests/python/relay/test_vm_serialization.py | 4 +- .../test_autotvm_graph_tuner_utils.py | 2 +- 64 files changed, 304 insertions(+), 409 deletions(-) rename python/tvm/relay/{ir => }/_ffi_api.py (100%) rename python/tvm/relay/{ir => }/_parser.py (99%) rename python/tvm/relay/{ir => }/adt.py (92%) rename python/tvm/relay/{ir => }/base.py (59%) rename python/tvm/relay/{ir => }/expr.py (97%) rename python/tvm/relay/{ir => }/expr_functor.py (99%) delete mode 100644 python/tvm/relay/ir/__init__.py rename python/tvm/relay/{ir => }/loops.py (100%) rename python/tvm/relay/{ir => }/parser.py (94%) rename python/tvm/relay/{ir => }/prelude.py (99%) rename python/tvm/relay/{ir => }/scope_builder.py (99%) rename python/tvm/relay/{ir => }/ty.py (97%) rename python/tvm/relay/{ir => }/type_functor.py (100%) rename tests/python/relay/{test_feature.py => test_analysis_feature.py} (98%) diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst index 61b18ac9c815..dc9dac0f67bd 100644 --- a/docs/api/python/relay/base.rst +++ b/docs/api/python/relay/base.rst @@ -15,16 +15,16 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.base +tvm.relay.base -------------- -.. automodule:: tvm.relay.ir.base +.. automodule:: tvm.relay.base -.. autofunction:: tvm.relay.ir.base.register_relay_node +.. autofunction:: tvm.relay.base.register_relay_node -.. autofunction:: tvm.relay.ir.base.register_relay_attr_node +.. autofunction:: tvm.relay.base.register_relay_attr_node -.. autoclass:: tvm.relay.ir.base.RelayNode +.. autoclass:: tvm.relay.base.RelayNode :members: -.. autoclass:: tvm.relay.ir.base.Id +.. autoclass:: tvm.relay.base.Id :members: diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst index e7944730a4b8..57a4a2511b72 100644 --- a/docs/api/python/relay/expr.rst +++ b/docs/api/python/relay/expr.rst @@ -15,56 +15,56 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.expr +tvm.relay.expr -------------- -.. automodule:: tvm.relay.ir.expr +.. automodule:: tvm.relay.expr -.. autofunction:: tvm.relay.ir.expr.var +.. autofunction:: tvm.relay.expr.var -.. autofunction:: tvm.relay.ir.expr.const +.. autofunction:: tvm.relay.expr.const -.. autofunction:: tvm.relay.ir.expr.bind +.. autofunction:: tvm.relay.expr.bind -.. autoclass:: tvm.relay.ir.expr.Expr +.. autoclass:: tvm.relay.expr.Expr :members: -.. autoclass:: tvm.relay.ir.expr.Constant +.. autoclass:: tvm.relay.expr.Constant :members: -.. autoclass:: tvm.relay.ir.expr.Tuple +.. autoclass:: tvm.relay.expr.Tuple :members: -.. autoclass:: tvm.relay.ir.expr.Function +.. autoclass:: tvm.relay.expr.Function :members: -.. autoclass:: tvm.relay.ir.expr.Call +.. autoclass:: tvm.relay.expr.Call :members: -.. autoclass:: tvm.relay.ir.expr.Let +.. autoclass:: tvm.relay.expr.Let :members: -.. autoclass:: tvm.relay.ir.expr.If +.. autoclass:: tvm.relay.expr.If :members: -.. autoclass:: tvm.relay.ir.expr.TupleGetItem +.. autoclass:: tvm.relay.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.ir.expr.RefCreate +.. autoclass:: tvm.relay.expr.RefCreate :members: -.. autoclass:: tvm.relay.ir.expr.RefRead +.. autoclass:: tvm.relay.expr.RefRead :members: -.. autoclass:: tvm.relay.ir.expr.RefWrite +.. autoclass:: tvm.relay.expr.RefWrite :members: -.. autoclass:: tvm.relay.ir.expr.TupleGetItem +.. autoclass:: tvm.relay.expr.TupleGetItem :members: -.. autoclass:: tvm.relay.ir.expr.TempExpr +.. autoclass:: tvm.relay.expr.TempExpr :members: -.. autoclass:: tvm.relay.ir.expr.TupleWrapper +.. autoclass:: tvm.relay.expr.TupleWrapper :members: diff --git a/docs/api/python/relay/scope_builder.rst b/docs/api/python/relay/scope_builder.rst index 730751f7a581..6d8e01428e31 100644 --- a/docs/api/python/relay/scope_builder.rst +++ b/docs/api/python/relay/scope_builder.rst @@ -15,10 +15,10 @@ specific language governing permissions and limitations under the License. -tvm.relay.ir.scope_builder +tvm.relay.scope_builder ----------------------- -.. automodule:: tvm.relay.ir.scope_builder +.. automodule:: tvm.relay.scope_builder -.. autoclass:: tvm.relay.ir.scope_builder.ScopeBuilder +.. autoclass:: tvm.relay.scope_builder.ScopeBuilder :members: diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index e463d20242d7..f1dd40440532 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,8 +21,8 @@ import tvm from tvm import relay, autotvm from tvm.relay import transform -from tvm.relay.ir import Call, Function, TupleGetItem, Var, Constant, Tuple -from tvm.relay.ir import TupleType, TensorType +from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 8184b586016c..f24fe44f9041 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,9 +19,16 @@ import os from sys import setrecursionlimit -from . import ir -from .ir import adt, expr, ty, base, scope_builder -from .ir import prelude, loops, parser +from . import base +from . import ty +from . import expr +from . import type_functor +from . import expr_functor +from . import adt +from . import prelude +from . import loops +from . import scope_builder +from . import parser from . import transform from . import analysis @@ -60,65 +67,66 @@ setrecursionlimit(10000) # Span -Span = ir.Span +Span = base.Span +SourceName = base.SourceName # Type -Type = ir.Type -TupleType = ir.TupleType -TensorType = ir.TensorType -TypeKind = ir.TypeKind -TypeVar = ir.TypeVar -ShapeVar = ir.ShapeVar -TypeConstraint = ir.TypeConstraint -FuncType = ir.FuncType -TypeRelation = ir.TypeRelation -IncompleteType = ir.IncompleteType -scalar_type = ir.scalar_type -RefType = ir.RefType -GlobalTypeVar = ir.GlobalTypeVar -TypeCall = ir.TypeCall -Any = ir.Any +Type = ty.Type +TupleType = ty.TupleType +TensorType = ty.TensorType +TypeKind = ty.TypeKind +TypeVar = ty.TypeVar +ShapeVar = ty.ShapeVar +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType +TypeRelation = ty.TypeRelation +IncompleteType = ty.IncompleteType +scalar_type = ty.scalar_type +RefType = ty.RefType +GlobalTypeVar = ty.GlobalTypeVar +TypeCall = ty.TypeCall +Any = ty.Any # Expr -Expr = ir.Expr -Constant = ir.Constant -Tuple = ir.Tuple -Var = ir.Var -GlobalVar = ir.GlobalVar -Function = ir.Function -Call = ir.Call -Let = ir.Let -If = ir.If -TupleGetItem = ir.TupleGetItem -RefCreate = ir.RefCreate -RefRead = ir.RefRead -RefWrite = ir.RefWrite +Expr = expr.RelayExpr +Constant = expr.Constant +Tuple = expr.Tuple +Var = expr.Var +GlobalVar = expr.GlobalVar +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If +TupleGetItem = expr.TupleGetItem +RefCreate = expr.RefCreate +RefRead = expr.RefRead +RefWrite = expr.RefWrite # ADT -Pattern = ir.Pattern -PatternWildcard = ir.PatternWildcard -PatternVar = ir.PatternVar -PatternConstructor = ir.PatternConstructor -PatternTuple = ir.PatternTuple -Constructor = ir.Constructor -TypeData = ir.TypeData -Clause = ir.Clause -Match = ir.Match +Pattern = adt.Pattern +PatternWildcard = adt.PatternWildcard +PatternVar = adt.PatternVar +PatternConstructor = adt.PatternConstructor +PatternTuple = adt.PatternTuple +Constructor = adt.Constructor +TypeData = adt.TypeData +Clause = adt.Clause +Match = adt.Match # helper functions -var = ir.var -const = ir.const -bind = ir.bind +var = expr.var +const = expr.const +bind = expr.bind # TypeFunctor -TypeFunctor = ir.TypeFunctor -TypeVisitor = ir.TypeVisitor -TypeMutator = ir.TypeMutator +TypeFunctor = type_functor.TypeFunctor +TypeVisitor = type_functor.TypeVisitor +TypeMutator = type_functor.TypeMutator # ExprFunctor -ExprFunctor = ir.ExprFunctor -ExprVisitor = ir.ExprVisitor -ExprMutator = ir.ExprMutator +ExprFunctor = expr_functor.ExprFunctor +ExprVisitor = expr_functor.ExprVisitor +ExprMutator = expr_functor.ExprMutator # Prelude Prelude = prelude.Prelude diff --git a/python/tvm/relay/ir/_ffi_api.py b/python/tvm/relay/_ffi_api.py similarity index 100% rename from python/tvm/relay/ir/_ffi_api.py rename to python/tvm/relay/_ffi_api.py diff --git a/python/tvm/relay/ir/_parser.py b/python/tvm/relay/_parser.py similarity index 99% rename from python/tvm/relay/ir/_parser.py rename to python/tvm/relay/_parser.py index 354014a78862..49bdbb393c2e 100644 --- a/python/tvm/relay/ir/_parser.py +++ b/python/tvm/relay/_parser.py @@ -40,11 +40,11 @@ def __new__(cls, *args, **kwds): import tvm.ir._ffi_api from tvm.ir import IRModule -from . import Span, SourceName +from .base import Span, SourceName from . import adt from . import expr from . import ty -from .. import op +from . import op PYTHON_VERSION = sys.version_info.major try: @@ -56,9 +56,9 @@ def __new__(cls, *args, **kwds): .format(version=PYTHON_VERSION)) try: - from ..grammar.py3.RelayVisitor import RelayVisitor - from ..grammar.py3.RelayParser import RelayParser - from ..grammar.py3.RelayLexer import RelayLexer + from .grammar.py3.RelayVisitor import RelayVisitor + from .grammar.py3.RelayParser import RelayParser + from .grammar.py3.RelayLexer import RelayLexer except ImportError: raise Exception("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.") diff --git a/python/tvm/relay/ir/adt.py b/python/tvm/relay/adt.py similarity index 92% rename from python/tvm/relay/ir/adt.py rename to python/tvm/relay/adt.py index 8b4127286948..df12aaece2da 100644 --- a/python/tvm/relay/ir/adt.py +++ b/python/tvm/relay/adt.py @@ -17,8 +17,10 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Algebraic data types in Relay.""" from tvm.ir import Constructor, TypeData +from tvm.runtime import Object +import tvm._ffi -from .base import RelayNode, register_relay_node, Object +from .base import RelayNode from . import _ffi_api from .ty import Type from .expr import ExprWithOp, RelayExpr, Call @@ -28,7 +30,7 @@ class Pattern(RelayNode): """Base type for pattern matching constructs.""" -@register_relay_node +@tvm._ffi.register_object("relay.PatternWildcard") class PatternWildcard(Pattern): """Wildcard pattern in Relay: Matches any ADT and binds nothing.""" @@ -47,7 +49,7 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.PatternWildcard) -@register_relay_node +@tvm._ffi.register_object("relay.PatternVar") class PatternVar(Pattern): """Variable pattern in Relay: Matches anything and binds it to the variable.""" @@ -66,7 +68,7 @@ def __init__(self, var): self.__init_handle_by_constructor__(_ffi_api.PatternVar, var) -@register_relay_node +@tvm._ffi.register_object("relay.PatternConstructor") class PatternConstructor(Pattern): """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" @@ -91,7 +93,7 @@ def __init__(self, constructor, patterns=None): self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.PatternTuple") class PatternTuple(Pattern): """Constructor pattern in Relay: Matches a tuple, binds recursively.""" @@ -114,7 +116,7 @@ def __init__(self, patterns=None): self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns) -@register_relay_node +@tvm._ffi.register_object("relay.Clause") class Clause(Object): """Clause for pattern matching in Relay.""" @@ -136,7 +138,7 @@ def __init__(self, lhs, rhs): self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs) -@register_relay_node +@tvm._ffi.register_object("relay.Match") class Match(ExprWithOp): """Pattern matching expression in Relay.""" diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 2e4465bac08c..beb3c6599e28 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -24,7 +24,7 @@ from . import _ffi_api from .feature import Feature -from ..ir import Type +from ..ty import Type def post_order_visit(expr, fvisit): @@ -34,7 +34,7 @@ def post_order_visit(expr, fvisit): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. fvisit : function @@ -48,7 +48,7 @@ def well_formed(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -95,7 +95,7 @@ def check_constant(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -111,7 +111,7 @@ def free_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -133,7 +133,7 @@ def bound_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -149,7 +149,7 @@ def all_vars(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression Returns @@ -165,7 +165,7 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -185,7 +185,7 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -205,7 +205,7 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr : Union[tvm.relay.ir.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type mod : Optional[tvm.IRModule] @@ -225,10 +225,10 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns @@ -244,10 +244,10 @@ def assert_alpha_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. """ _ffi_api._assert_alpha_equal(lhs, rhs) @@ -261,10 +261,10 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns @@ -283,10 +283,10 @@ def assert_graph_equal(lhs, rhs): Parameters ---------- - lhs : tvm.relay.ir.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs : tvm.relay.ir.Expr + rhs : tvm.relay.Expr One of the input Expression. """ _ffi_api._assert_graph_equal(lhs, rhs) @@ -298,13 +298,13 @@ def collect_device_info(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns ------- ret : Dict[tvm.relay.ir.expr, int] - A dictionary mapping tvm.relay.ir.Expr to device type. + A dictionary mapping tvm.relay.Expr to device type. """ return _ffi_api.CollectDeviceInfo(expr) @@ -314,13 +314,13 @@ def collect_device_annotation_ops(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns ------- - ret : Dict[tvm.relay.ir.Expr, int] - A dictionary mapping tvm.relay.ir.Expr to device type where the keys are + ret : Dict[tvm.relay.Expr, int] + A dictionary mapping tvm.relay.Expr to device type where the keys are annotation expressions. """ return _ffi_api.CollectDeviceAnnotationOps(expr) @@ -332,7 +332,7 @@ def get_total_mac_number(expr): Parameters ---------- - expr : tvm.relay.ir.Expr + expr : tvm.relay.Expr The input expression. Returns @@ -369,10 +369,10 @@ def detect_feature(a, b=None): Parameters ---------- - a : Union[tvm.relay.ir.Expr, tvm.IRModule] + a : Union[tvm.relay.Expr, tvm.IRModule] The input expression or module. - b : Optional[Union[tvm.relay.ir.Expr, tvm.IRModule]] + b : Optional[Union[tvm.relay.Expr, tvm.IRModule]] The input expression or module. The two arguments cannot both be expression or module. @@ -391,7 +391,7 @@ def structural_hash(value): Parameters ---------- - expr : Union[tvm.relay.ir.Expr, tvm.relay.Type] + expr : Union[tvm.relay.Expr, tvm.relay.Type] The expression to hash. Returns @@ -405,7 +405,7 @@ def structural_hash(value): return int(_ffi_api._type_hash(value)) else: msg = ("found value of type {0} expected" + - "relay.ir.Expr or relay.Type").format(type(value)) + "relay.Expr or relay.Type").format(type(value)) raise TypeError(msg) diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py index 0d1053612c8d..966659aac494 100644 --- a/python/tvm/relay/analysis/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -19,7 +19,7 @@ from tvm.ir import IRModule from tvm.runtime import Object -from ..ir import GlobalVar +from ..expr import GlobalVar from . import _ffi_api diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 38efafeee66e..03d91d5beb0f 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -22,7 +22,7 @@ import numpy as np import tvm from tvm import te -from ..ir.base import register_relay_node, Object +from tvm.runtime import Object from ... import target as _target from ... import autotvm from .. import expr as _expr @@ -33,7 +33,7 @@ logger = logging.getLogger('compile_engine') -@register_relay_node +@tvm._ffi.register_object("relay.LoweredOutput") class LoweredOutput(Object): """Lowered output""" def __init__(self, outputs, implement): @@ -41,7 +41,7 @@ def __init__(self, outputs, implement): _backend._make_LoweredOutput, outputs, implement) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): """Key in the CompileEngine. @@ -58,7 +58,7 @@ def __init__(self, source_func, target): _backend._make_CCacheKey, source_func, target) -@register_relay_node +@tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): """Value in the CompileEngine, including usage statistics. """ @@ -261,7 +261,7 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@register_relay_node +@tvm._ffi.register_object("relay.CompileEngine") class CompileEngine(Object): """CompileEngine to get lowered code. """ diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index a7245987ab01..ab39f7c56446 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -20,25 +20,25 @@ import numpy as np -from tvm.runtime import container +import tvm._ffi +from tvm.runtime import container, Object from tvm.ir import IRModule from . import _backend from .. import _make, analysis, transform from ... import nd -from ..ir.base import Object, register_relay_node -from ..ir import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const -from ..ir.scope_builder import ScopeBuilder +from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const +from ..scope_builder import ScopeBuilder -@register_relay_node +@tvm._ffi.register_object("relay.ConstructorValue") class ConstructorValue(Object): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) -@register_relay_node +@tvm._ffi.register_object("relay.RefValue") class RefValue(Object): def __init__(self, value): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/ir/base.py b/python/tvm/relay/base.py similarity index 59% rename from python/tvm/relay/ir/base.py rename to python/tvm/relay/base.py index ad801ae23062..2c35681deb80 100644 --- a/python/tvm/relay/ir/base.py +++ b/python/tvm/relay/base.py @@ -23,43 +23,15 @@ from tvm.ir import SourceName, Span, Node as RelayNode -__STD_PATH__ = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), \ - os.pardir), "std") +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + @tvm._ffi.register_func("tvm.relay.std_path") def _std_path(): return __STD_PATH__ -def register_relay_node(type_key=None): - """Register a Relay node type. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -def register_relay_attr_node(type_key=None): - """Register a Relay attribute node. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - if not isinstance(type_key, str): - return tvm._ffi.register_object( - "relay.attrs." + type_key.__name__)(type_key) - return tvm._ffi.register_object(type_key) - - -@register_relay_node +@tvm._ffi.register_object("relay.Id") class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. diff --git a/python/tvm/relay/ir/expr.py b/python/tvm/relay/expr.py similarity index 97% rename from python/tvm/relay/ir/expr.py rename to python/tvm/relay/expr.py index 1ee187b27e7b..380cdf7d90ef 100644 --- a/python/tvm/relay/ir/expr.py +++ b/python/tvm/relay/expr.py @@ -20,11 +20,12 @@ from numbers import Number as _Number import numpy as _np +import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.ir import RelayExpr, GlobalVar, BaseFunc -from .base import RelayNode, register_relay_node +from .base import RelayNode from . import _ffi_api from . import ty as _ty @@ -159,7 +160,7 @@ def __call__(self, *args): """ return Call(self, args) -@register_relay_node +@tvm._ffi.register_object("relay.Constant") class Constant(ExprWithOp): """A constant expression in Relay. @@ -172,7 +173,7 @@ def __init__(self, data): self.__init_handle_by_constructor__(_ffi_api.Constant, data) -@register_relay_node +@tvm._ffi.register_object("relay.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -196,7 +197,7 @@ def astype(self, _): raise TypeError("astype cannot be used on tuple") -@register_relay_node +@tvm._ffi.register_object("relay.Var") class Var(ExprWithOp): """A local variable in Relay. @@ -224,7 +225,7 @@ def name_hint(self): return name -@register_relay_node +@tvm._ffi.register_object("relay.Function") class Function(BaseFunc): """A function declaration expression. @@ -286,7 +287,7 @@ def with_attr(self, attr_key, attr_value): -@register_relay_node +@tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. @@ -315,7 +316,7 @@ def __init__(self, op, args, attrs=None, type_args=None): _ffi_api.Call, op, args, attrs, type_args) -@register_relay_node +@tvm._ffi.register_object("relay.Let") class Let(ExprWithOp): """Let variable binding expression. @@ -335,7 +336,7 @@ def __init__(self, variable, value, body): _ffi_api.Let, variable, value, body) -@register_relay_node +@tvm._ffi.register_object("relay.If") class If(ExprWithOp): """A conditional expression in Relay. @@ -355,7 +356,7 @@ def __init__(self, cond, true_branch, false_branch): _ffi_api.If, cond, true_branch, false_branch) -@register_relay_node +@tvm._ffi.register_object("relay.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -372,7 +373,7 @@ def __init__(self, tuple_value, index): _ffi_api.TupleGetItem, tuple_value, index) -@register_relay_node +@tvm._ffi.register_object("relay.RefCreate") class RefCreate(ExprWithOp): """Create a new reference from initial value. Parameters @@ -384,7 +385,7 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.RefCreate, value) -@register_relay_node +@tvm._ffi.register_object("relay.RefRead") class RefRead(ExprWithOp): """Get the value inside the reference. Parameters @@ -396,7 +397,7 @@ def __init__(self, ref): self.__init_handle_by_constructor__(_ffi_api.RefRead, ref) -@register_relay_node +@tvm._ffi.register_object("relay.RefWrite") class RefWrite(ExprWithOp): """ Update the value inside the reference. diff --git a/python/tvm/relay/ir/expr_functor.py b/python/tvm/relay/expr_functor.py similarity index 99% rename from python/tvm/relay/ir/expr_functor.py rename to python/tvm/relay/expr_functor.py index d3beb200e407..8d6923920979 100644 --- a/python/tvm/relay/ir/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -21,7 +21,7 @@ from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause -from ..op import Op +from .op import Op class ExprFunctor: """ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7b16bc9849aa..6da91c17fd94 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -30,7 +30,7 @@ from .. import analysis as _analysis from .. import expr as _expr from .. import op as _op -from ..ir.loops import while_loop +from ..loops import while_loop from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d1abb4f1a5d3..29d9d1bcb93b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,12 +27,12 @@ import tvm from tvm.ir import IRModule -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from .. import analysis from .. import expr as _expr from .. import op as _op -from ..ir import ExprMutator +from ..expr_functor import ExprMutator from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/ir/__init__.py b/python/tvm/relay/ir/__init__.py deleted file mode 100644 index 2a141aa271ed..000000000000 --- a/python/tvm/relay/ir/__init__.py +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -"""The Relay IR namespace.""" -from . import base -from . import ty -from . import expr -from . import type_functor -from . import expr_functor -from . import adt -from . import prelude -from . import loops -from . import scope_builder - -# Span -Span = base.Span -SourceName = base.SourceName - -# Type -Type = ty.Type -TupleType = ty.TupleType -TensorType = ty.TensorType -TypeKind = ty.TypeKind -TypeVar = ty.TypeVar -ShapeVar = ty.ShapeVar -TypeConstraint = ty.TypeConstraint -FuncType = ty.FuncType -TypeRelation = ty.TypeRelation -IncompleteType = ty.IncompleteType -scalar_type = ty.scalar_type -RefType = ty.RefType -GlobalTypeVar = ty.GlobalTypeVar -TypeCall = ty.TypeCall -Any = ty.Any - -# Expr -Expr = expr.RelayExpr -Constant = expr.Constant -Tuple = expr.Tuple -Var = expr.Var -GlobalVar = expr.GlobalVar -Function = expr.Function -Call = expr.Call -Let = expr.Let -If = expr.If -TupleGetItem = expr.TupleGetItem -RefCreate = expr.RefCreate -RefRead = expr.RefRead -RefWrite = expr.RefWrite - -# ADT -Pattern = adt.Pattern -PatternWildcard = adt.PatternWildcard -PatternVar = adt.PatternVar -PatternConstructor = adt.PatternConstructor -PatternTuple = adt.PatternTuple -Constructor = adt.Constructor -TypeData = adt.TypeData -Clause = adt.Clause -Match = adt.Match - -# helper functions -var = expr.var -const = expr.const -bind = expr.bind - -# TypeFunctor -TypeFunctor = type_functor.TypeFunctor -TypeVisitor = type_functor.TypeVisitor -TypeMutator = type_functor.TypeMutator - -# ExprFunctor -ExprFunctor = expr_functor.ExprFunctor -ExprVisitor = expr_functor.ExprVisitor -ExprMutator = expr_functor.ExprMutator - -# Prelude -Prelude = prelude.Prelude - -# Scope builder -ScopeBuilder = scope_builder.ScopeBuilder diff --git a/python/tvm/relay/ir/loops.py b/python/tvm/relay/loops.py similarity index 100% rename from python/tvm/relay/ir/loops.py rename to python/tvm/relay/loops.py diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index acc975cbe854..b3054d67885b 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,13 +41,12 @@ from . import _transform from . import _reduce from . import _algorithm -from ..ir.base import register_relay_node def _register_op_make(): # pylint: disable=import-outside-toplevel from . import _make - from ..ir import expr + from .. import expr expr._op_make = _make _register_op_make() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 039b2f738883..33a193799288 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -21,7 +21,7 @@ from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple -from ..ir.expr import Tuple, TupleGetItem, const +from ..expr import Tuple, TupleGetItem, const from . import nn as _nn from .op import register_gradient from .reduce import sum as _sum diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 414b458cbd7c..17fab80118af 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,7 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make -from ..ir.expr import TupleWrapper +from ..expr import TupleWrapper def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 36f2fa565c9f..30918a4183b1 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -17,7 +17,7 @@ #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" from __future__ import absolute_import as _abs -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper from . import _make from .util import get_pad_tuple2d diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 26a6b6ead8cd..e6bd6bf230dd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -19,13 +19,12 @@ import tvm._ffi from tvm.driver import lower, build -from ..ir.base import register_relay_node -from ..ir.expr import RelayExpr +from ..expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make -@register_relay_node +@tvm._ffi.register_object("relay.Op") class Op(RelayExpr): """A Relay operator definition.""" diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 9224f570f39d..2f68f7074427 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -16,314 +16,326 @@ # under the License. """The attributes node used for Relay operators""" +import tvm._ffi from tvm.ir import Attrs -from ..ir.base import register_relay_attr_node -@register_relay_attr_node +def _register_relay_attr_node(type_key=None): + """Register a Relay attribute node. + + Parameters + ---------- + type_key : str or cls + The type key of the node. + """ + return tvm._ffi.register_object( + "relay.attrs." + type_key.__name__)(type_key) + + +@_register_relay_attr_node class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform""" -@register_relay_attr_node +@_register_relay_attr_node class GlobalPool2DAttrs(Attrs): """Attributes for nn.global_pool""" -@register_relay_attr_node +@_register_relay_attr_node class BiasAddAttrs(Attrs): """Atttribute of nn.bias_add""" -@register_relay_attr_node +@_register_relay_attr_node class DenseAttrs(Attrs): """Attributes for nn.dense""" -@register_relay_attr_node +@_register_relay_attr_node class FIFOBufferAttrs(Attrs): """Attributes for nn.fifo_buffer""" -@register_relay_attr_node +@_register_relay_attr_node class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" -@register_relay_attr_node +@_register_relay_attr_node class UpSampling3DAttrs(Attrs): """Attributes for nn.upsampling3d""" -@register_relay_attr_node +@_register_relay_attr_node class PadAttrs(Attrs): """Attributes for nn.pad""" -@register_relay_attr_node +@_register_relay_attr_node class MirrorPadAttrs(Attrs): """Attributes for nn.mirror_pad""" -@register_relay_attr_node +@_register_relay_attr_node class LeakyReluAttrs(Attrs): """Attributes for nn.leaky_relu""" -@register_relay_attr_node +@_register_relay_attr_node class PReluAttrs(Attrs): """Attributes for nn.prelu""" -@register_relay_attr_node +@_register_relay_attr_node class DropoutAttrs(Attrs): """Attributes for nn.dropout""" -@register_relay_attr_node +@_register_relay_attr_node class BatchNormAttrs(Attrs): """Attributes for nn.batch_norm""" -@register_relay_attr_node +@_register_relay_attr_node class LRNAttrs(Attrs): """Attributes for nn.lrn""" -@register_relay_attr_node +@_register_relay_attr_node class L2NormalizeAttrs(Attrs): """Attributes for nn.l2_normalize""" -@register_relay_attr_node +@_register_relay_attr_node class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@register_relay_attr_node +@_register_relay_attr_node class ResizeAttrs(Attrs): """Attributes for image.resize""" -@register_relay_attr_node +@_register_relay_attr_node class CropAndResizeAttrs(Attrs): """Attributes for image.crop_and_resize""" -@register_relay_attr_node +@_register_relay_attr_node class ArgsortAttrs(Attrs): """Attributes for algorithm.argsort""" -@register_relay_attr_node +@_register_relay_attr_node class OnDeviceAttrs(Attrs): """Attributes for annotation.on_device""" -@register_relay_attr_node +@_register_relay_attr_node class DebugAttrs(Attrs): """Attributes for debug""" -@register_relay_attr_node +@_register_relay_attr_node class DeviceCopyAttrs(Attrs): """Attributes for tensor.device_copy""" -@register_relay_attr_node +@_register_relay_attr_node class CastAttrs(Attrs): """Attributes for transform.cast""" -@register_relay_attr_node +@_register_relay_attr_node class ConcatenateAttrs(Attrs): """Attributes for tensor.concatenate""" -@register_relay_attr_node +@_register_relay_attr_node class TransposeAttrs(Attrs): """Attributes for transform.transpose""" -@register_relay_attr_node +@_register_relay_attr_node class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" -@register_relay_attr_node +@_register_relay_attr_node class TakeAttrs(Attrs): """Attributes for transform.take""" -@register_relay_attr_node +@_register_relay_attr_node class InitOpAttrs(Attrs): """Attributes for ops specifying a tensor""" -@register_relay_attr_node +@_register_relay_attr_node class ArangeAttrs(Attrs): """Attributes used in arange operators""" -@register_relay_attr_node +@_register_relay_attr_node class StackAttrs(Attrs): """Attributes used in stack operators""" -@register_relay_attr_node +@_register_relay_attr_node class RepeatAttrs(Attrs): """Attributes used in repeat operators""" -@register_relay_attr_node +@_register_relay_attr_node class TileAttrs(Attrs): """Attributes used in tile operators""" -@register_relay_attr_node +@_register_relay_attr_node class ReverseAttrs(Attrs): """Attributes used in reverse operators""" -@register_relay_attr_node +@_register_relay_attr_node class SqueezeAttrs(Attrs): """Attributes used in squeeze operators""" -@register_relay_attr_node +@_register_relay_attr_node class SplitAttrs(Attrs): """Attributes for transform.split""" -@register_relay_attr_node +@_register_relay_attr_node class StridedSliceAttrs(Attrs): """Attributes for transform.stranded_slice""" -@register_relay_attr_node +@_register_relay_attr_node class SliceLikeAttrs(Attrs): """Attributes for transform.slice_like""" -@register_relay_attr_node +@_register_relay_attr_node class ClipAttrs(Attrs): """Attributes for transform.clip""" -@register_relay_attr_node +@_register_relay_attr_node class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@register_relay_attr_node +@_register_relay_attr_node class ShapeOfAttrs(Attrs): """Attributes for tensor.shape_of""" -@register_relay_attr_node +@_register_relay_attr_node class MultiBoxPriorAttrs(Attrs): """Attributes for vision.multibox_prior""" -@register_relay_attr_node +@_register_relay_attr_node class MultiBoxTransformLocAttrs(Attrs): """Attributes for vision.multibox_transform_loc""" -@register_relay_attr_node +@_register_relay_attr_node class GetValidCountsAttrs(Attrs): """Attributes for vision.get_valid_counts""" -@register_relay_attr_node +@_register_relay_attr_node class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" -@register_relay_attr_node +@_register_relay_attr_node class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" -@register_relay_attr_node +@_register_relay_attr_node class ROIPoolAttrs(Attrs): """Attributes for vision.roi_pool""" -@register_relay_attr_node +@_register_relay_attr_node class YoloReorgAttrs(Attrs): """Attributes for vision.yolo_reorg""" -@register_relay_attr_node +@_register_relay_attr_node class ProposalAttrs(Attrs): """Attributes used in proposal operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool2DAttrs(Attrs): """Attributes used in max_pool2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool1DAttrs(Attrs): """Attributes used in max_pool1d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool1DAttrs(Attrs): """Attributes used in avg_pool1d operators""" -@register_relay_attr_node +@_register_relay_attr_node class MaxPool3DAttrs(Attrs): """Attributes used in max_pool3d operators""" -@register_relay_attr_node +@_register_relay_attr_node class AvgPool3DAttrs(Attrs): """Attributes used in avg_pool3d operators""" -@register_relay_attr_node +@_register_relay_attr_node class BitPackAttrs(Attrs): """Attributes used in bitpack operator""" -@register_relay_attr_node +@_register_relay_attr_node class BinaryConv2DAttrs(Attrs): """Attributes used in bitserial conv2d operators""" -@register_relay_attr_node +@_register_relay_attr_node class BinaryDenseAttrs(Attrs): """Attributes used in bitserial dense operators""" -@register_relay_attr_node +@_register_relay_attr_node class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" -@register_relay_attr_node +@_register_relay_attr_node class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 05f8932e396c..d3226012e887 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -20,7 +20,7 @@ from . import _make from .tensor import sqrt from .transform import squeeze -from ..ir.expr import Tuple, TupleWrapper +from ..expr import Tuple, TupleWrapper def argmax(data, axis=None, keepdims=False, exclude=False): """Returns the indices of the maximum values along an axis. diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 90e4604e80f3..77969185c0a7 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -20,7 +20,7 @@ from tvm.runtime import TVMContext as _TVMContext from . import _make -from ..ir.expr import Tuple +from ..expr import Tuple # We create a wrapper function for each operator in the diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 3e18447c2c6d..6a30eb2b7ec9 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -19,7 +19,7 @@ """Transform operators.""" from . import _make -from ..ir.expr import TupleWrapper, const +from ..expr import TupleWrapper, const def cast(data, dtype): @@ -38,7 +38,7 @@ def cast(data, dtype): result : relay.Expr The casted result. """ - from ..ir import _ffi_api as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast(data, dtype) @@ -55,7 +55,7 @@ def cast_like(data, dtype_like): result : relay.Expr The casted result. """ - from .. import _make as _relay_make + from .. import _ffi_api as _relay_make return _relay_make.cast_like(data, dtype_like) diff --git a/python/tvm/relay/op/vision/multibox.py b/python/tvm/relay/op/vision/multibox.py index 73f4893412fc..55fb01c5eaef 100644 --- a/python/tvm/relay/op/vision/multibox.py +++ b/python/tvm/relay/op/vision/multibox.py @@ -16,7 +16,7 @@ # under the License. """Multibox operations.""" from . import _make -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper def multibox_prior(data, sizes=(1.0,), diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 859c4999545d..cba08bfba824 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -16,7 +16,7 @@ # under the License. """Non-maximum suppression operations.""" from . import _make -from ...ir.expr import TupleWrapper +from ...expr import TupleWrapper def get_valid_counts(data, score_threshold, diff --git a/python/tvm/relay/ir/parser.py b/python/tvm/relay/parser.py similarity index 94% rename from python/tvm/relay/ir/parser.py rename to python/tvm/relay/parser.py index 053d299da022..6c4e3131e3c2 100644 --- a/python/tvm/relay/ir/parser.py +++ b/python/tvm/relay/parser.py @@ -16,14 +16,14 @@ # under the License. """A parser for Relay's text format.""" from __future__ import absolute_import -from ... import register_func +from .. import register_func @register_func("relay.fromtext") def fromtext(data, source_name=None): """Parse a Relay program.""" # pylint: disable=import-outside-toplevel - from . import _parser + from tvm.relay import _parser x = _parser.fromtext(data + "\n", source_name) if x is None: raise Exception("cannot parse: ", data) diff --git a/python/tvm/relay/ir/prelude.py b/python/tvm/relay/prelude.py similarity index 99% rename from python/tvm/relay/ir/prelude.py rename to python/tvm/relay/prelude.py index fa68d3ae177a..5288a2e08011 100644 --- a/python/tvm/relay/ir/prelude.py +++ b/python/tvm/relay/prelude.py @@ -20,10 +20,10 @@ from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, Function, GlobalVar, If, const -from ..op.tensor import add, subtract, equal +from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard -from .. import op +from . import op class TensorArrayOps(object): diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index b53992637250..c94a4194daee 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,7 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm.relay.ir import Tuple +from tvm.relay.expr import Tuple from tvm.relay.op.nn.util import get_pad_tuple2d from . import _make diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 26d8a18f7b98..2658a0aa7dad 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -24,7 +24,6 @@ from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg -from ..ir.base import register_relay_node from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op @@ -58,7 +57,7 @@ def simulated_quantize_compute(attrs, inputs, out_type): _reg.register_injective_schedule("annotation.cast_hint") -@register_relay_node +@tvm._ffi.register_object("relay.QAnnotateExpr") class QAnnotateExpr(_expr.TempExpr): """A special kind of Expr for Annotating. diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 90274e879992..bb3db99eed79 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,7 +19,6 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..ir.base import register_relay_node from ..op import op as _reg from . import _quantize from .quantize import _forward_op @@ -30,7 +29,7 @@ def _register(func): return _register(frewrite) if frewrite is not None else _register -@register_relay_node +@tvm._ffi.register_object("relay.QPartitionExpr") class QPartitionExpr(_expr.TempExpr): def __init__(self, expr): self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index bd1ab6fd7dfc..2ad4e18771d7 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,12 +17,12 @@ #pylint: disable=unused-argument, not-context-manager """Automatic quantization toolkit.""" import tvm.ir +from tvm.runtime import Object from . import _quantize from ._calibrate import calibrate from .. import expr as _expr from .. import transform as _transform -from ..ir.base import Object, register_relay_node class QAnnotateKind(object): @@ -52,7 +52,7 @@ def _forward_op(ref_call, args): ref_call.op, args, ref_call.attrs, ref_call.type_args) -@register_relay_node("relay.quantize.QConfig") +@tvm._ffi.register_object("relay.quantize.QConfig") class QConfig(Object): """Configure the quantization behavior by setting config variables. diff --git a/python/tvm/relay/ir/scope_builder.py b/python/tvm/relay/scope_builder.py similarity index 99% rename from python/tvm/relay/ir/scope_builder.py rename to python/tvm/relay/scope_builder.py index 35357707e535..cd8dc8dcd309 100644 --- a/python/tvm/relay/ir/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -20,7 +20,7 @@ from . import ty as _ty from . import expr as _expr -from ..._ffi import base as _base +from .._ffi import base as _base class WithScope(object): """A wrapper for builder methods which introduce scoping. diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 068c02c6d53b..54c909179e4f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -23,9 +23,9 @@ from tvm import te import tvm.relay as relay import tvm.relay.op as op -from tvm.relay import transform, create_executor -from tvm.relay.ir import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem -from tvm.relay.ir import TensorType, TupleType +from tvm.relay import transform +from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor +from tvm.relay import TensorType, TupleType from . import mlp from . import resnet diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index d1110fdb19d5..eb71120610d3 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -19,10 +19,10 @@ Nats are useful for testing purposes, as they make it easy to write test cases for recursion and pattern matching.""" -from tvm.relay.ir import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.ir import Var, Function, GlobalVar -from tvm.relay.ir import GlobalTypeVar, TypeVar, FuncType +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType def define_nat_adt(prelude): """Defines a Peano (unary) natural number ADT. diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index e40436c4df22..eacfe379137f 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -21,10 +21,10 @@ import tvm from tvm import relay -from tvm.relay.ir import Pattern +from tvm.relay.adt import Pattern from tvm.relay.backend import compile_engine -from tvm.relay.ir import Expr, Function, GlobalVar, Var -from tvm.relay.ir import ExprFunctor +from tvm.relay.expr import Expr, Function, GlobalVar, Var +from tvm.relay.expr_functor import ExprFunctor OUTPUT_VAR_NAME = '_py_out' diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 7b5a0b75ac6c..c238730807d3 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,12 +19,12 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -from ..ir.expr_functor import ExprMutator -from ..ir.scope_builder import ScopeBuilder +from ..expr_functor import ExprMutator +from ..scope_builder import ScopeBuilder from . import transform from .. import op from ... import DataType, register_func -from ..ir import ty, expr +from .. import ty, expr from ..backend import compile_engine diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index e32de3e3d112..43a116e64e5b 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -29,7 +29,6 @@ from tvm import relay from . import _ffi_api -from ..ir.base import register_relay_node def build_config(opt_level=2, @@ -83,7 +82,7 @@ def build_config(opt_level=2, disabled_pass, trace) -@register_relay_node +@tvm._ffi.register_object("relay.FunctionPass") class FunctionPass(Pass): """A pass that works on each tvm.relay.Function in a module. A function pass class should be created through `function_pass`. diff --git a/python/tvm/relay/ir/ty.py b/python/tvm/relay/ty.py similarity index 97% rename from python/tvm/relay/ir/ty.py rename to python/tvm/relay/ty.py index b9643803f2f6..19cc10aba41e 100644 --- a/python/tvm/relay/ir/ty.py +++ b/python/tvm/relay/ty.py @@ -20,7 +20,7 @@ from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType -from .base import RelayNode, register_relay_node +from .base import RelayNode from . import _ffi_api Any = _ffi_api.Any diff --git a/python/tvm/relay/ir/type_functor.py b/python/tvm/relay/type_functor.py similarity index 100% rename from python/tvm/relay/ir/type_functor.py rename to python/tvm/relay/type_functor.py diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 491f18df3de0..deeb7330f9da 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_analysis_feature.py similarity index 98% rename from tests/python/relay/test_feature.py rename to tests/python/relay/test_analysis_feature.py index f54fa713e957..ec5deb3c4e60 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_analysis_feature.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.relay.analysis import detect_feature, Feature from tvm.relay.transform import gradient -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import run_infer_type def test_prelude(): diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 914ac12ed6ff..aa81e3113b7f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir.loops import while_loop +from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type def int32(val): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 0534c16dedaa..360b6bd20416 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.runtime import container from tvm.relay.backend.interpreter import RefValue, ConstructorValue -from tvm.relay.ir import ScopeBuilder +from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index bfc9accf5906..bab82472263a 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions def constructor_list(p): diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index daf436c584cc..db953d5762e3 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -18,7 +18,7 @@ from tvm import te from tvm import relay from tvm.relay.analysis import well_formed -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude def test_let(): x = relay.Var("x") diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index e98c7ec1fe6f..3e7d916c96fa 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -19,10 +19,9 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.contrib import graph_runtime -from tvm.relay.ir import ExprMutator +from tvm.relay.expr_functor import ExprMutator from tvm.relay import transform diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 48923b5ce061..6f2a12589fb5 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal from tvm.relay import create_executor, transform from tvm.relay.transform import gradient -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand import tvm.relay.op as op diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 2b865c57123d..aed026996a21 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -21,7 +21,8 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import ExprFunctor, Function, Call +from tvm.relay import ExprFunctor +from tvm.relay import Function, Call from tvm.relay import analysis from tvm.relay import transform as _transform from tvm.relay.testing import ctx_list, run_infer_type diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f2f6f85955e4..f54dd6bf69c5 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -20,11 +20,11 @@ from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, assert_alpha_equal -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform -from tvm.relay.ir import Var, TypeVar, TupleGetItem, Let, const, RefRead, RefWrite, RefCreate -from tvm.relay.ir import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match -from tvm.relay.ir import GlobalVar, Call, Function +from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate +from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match +from tvm.relay import GlobalVar, Call from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 36e371d68747..c4fbbc1458d9 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -26,7 +26,7 @@ from tvm.relay import transform from tvm.contrib import util from tvm.relay.op.annotation import compiler_begin, compiler_end -from tvm.relay.ir import ExprMutator +from tvm.relay.expr_functor import ExprMutator # Leverage the pass manager to write a simple white list based annotator @transform.function_pass(opt_level=0) diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index e42bdbd94986..33816344f562 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -19,7 +19,7 @@ from tvm import te from tvm import relay from tvm.relay import transform -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude def test_remove_all_prelude_functions(): diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 1f0c9f33cc73..2a6103ea1fbe 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -20,7 +20,7 @@ from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay import op, create_executor, transform -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.analysis import Feature diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 76b906cf96ab..e2ac924e9661 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -16,12 +16,11 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.transform import to_cps, un_cps from tvm.relay.analysis import Feature -from tvm.relay.ir import Prelude, Function +from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay import create_executor from tvm.relay import transform diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index bac0970b098b..42344bccabaa 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.relay.analysis import unmatched_cases import pytest diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 36ed034ba9eb..f6b1b2432d92 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -16,9 +16,10 @@ # under the License. import numpy as np import tvm +from tvm import te from tvm import relay from tvm.relay.testing import to_python, run_as_python -from tvm.relay.ir import Prelude +from tvm.relay.prelude import Prelude from tvm.runtime.container import ADT from tvm.relay.backend.interpreter import RefValue, ConstructorValue diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index 8f581b5b715c..9e023bc6b1e4 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -19,9 +19,9 @@ from tvm import relay from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor from tvm.relay.analysis import assert_graph_equal -from tvm.relay.ir import (TypeVar, IncompleteType, TensorType, FuncType, +from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType, TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall) -from tvm.relay.ir import TypeData +from tvm.relay.adt import TypeData def check_visit(typ): try: diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index b31f04c5edd9..f2b15ec26f32 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -20,10 +20,10 @@ import tvm from tvm import runtime from tvm import relay -from tvm.relay.ir import ScopeBuilder +from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.testing.config import ctx_list -from tvm.relay.ir import Prelude -from tvm.relay.ir.loops import while_loop +from tvm.relay.prelude import Prelude +from tvm.relay.loops import while_loop from tvm.relay import testing def check_result(args, expected_result, mod=None): diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index a2a786e6a8be..5d20651a8126 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -24,8 +24,8 @@ from tvm.relay import vm as rly_vm from tvm import relay -from tvm.relay.ir import ScopeBuilder -from tvm.relay.ir import Prelude +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing diff --git a/tests/python/unittest/test_autotvm_graph_tuner_utils.py b/tests/python/unittest/test_autotvm_graph_tuner_utils.py index b558010823e1..bd0ebe0cd3f5 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_utils.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_utils.py @@ -28,7 +28,7 @@ from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \ get_out_nodes, expr2graph, bind_inputs from tvm.autotvm.graph_tuner._base import OPT_OUT_OP -from tvm.relay.ir import Call, TupleGetItem, Tuple, Var +from tvm.relay.expr import Call, TupleGetItem, Tuple, Var def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result): From c2eb0b22a28dce7abb93deb7029b5c710e59468b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 16 Mar 2020 19:13:27 +0000 Subject: [PATCH 3/4] Address comments --- docs/api/python/relay/base.rst | 4 - python/tvm/relay/op/op_attrs.py | 141 +++++++++++++++----------------- 2 files changed, 64 insertions(+), 81 deletions(-) diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst index dc9dac0f67bd..8dcab78d3231 100644 --- a/docs/api/python/relay/base.rst +++ b/docs/api/python/relay/base.rst @@ -19,10 +19,6 @@ tvm.relay.base -------------- .. automodule:: tvm.relay.base -.. autofunction:: tvm.relay.base.register_relay_node - -.. autofunction:: tvm.relay.base.register_relay_attr_node - .. autoclass:: tvm.relay.base.RelayNode :members: diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 2f68f7074427..9a5fb5592e90 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -15,327 +15,314 @@ # specific language governing permissions and limitations # under the License. """The attributes node used for Relay operators""" - -import tvm._ffi from tvm.ir import Attrs +import tvm._ffi -def _register_relay_attr_node(type_key=None): - """Register a Relay attribute node. - - Parameters - ---------- - type_key : str or cls - The type key of the node. - """ - return tvm._ffi.register_object( - "relay.attrs." + type_key.__name__)(type_key) - - -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradAttrs") class Conv2DWinogradAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs") class Conv2DWinogradWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_weight_transform""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs): """Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.GlobalPool2DAttrs") class GlobalPool2DAttrs(Attrs): """Attributes for nn.global_pool""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BiasAddAttrs") class BiasAddAttrs(Attrs): """Atttribute of nn.bias_add""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DenseAttrs") class DenseAttrs(Attrs): """Attributes for nn.dense""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs") class FIFOBufferAttrs(Attrs): """Attributes for nn.fifo_buffer""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.UpSamplingAttrs") class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.UpSampling3DAttrs") class UpSampling3DAttrs(Attrs): """Attributes for nn.upsampling3d""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes for nn.pad""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MirrorPadAttrs") class MirrorPadAttrs(Attrs): """Attributes for nn.mirror_pad""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LeakyReluAttrs") class LeakyReluAttrs(Attrs): """Attributes for nn.leaky_relu""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.PReluAttrs") class PReluAttrs(Attrs): """Attributes for nn.prelu""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for nn.dropout""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes for nn.batch_norm""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LRNAttrs") class LRNAttrs(Attrs): """Attributes for nn.lrn""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.L2NormalizeAttrs") class L2NormalizeAttrs(Attrs): """Attributes for nn.l2_normalize""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DeformableConv2DAttrs") class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ResizeAttrs") class ResizeAttrs(Attrs): """Attributes for image.resize""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs") class CropAndResizeAttrs(Attrs): """Attributes for image.crop_and_resize""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for algorithm.argsort""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs") class OnDeviceAttrs(Attrs): """Attributes for annotation.on_device""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.DebugAttrs") class DebugAttrs(Attrs): """Attributes for debug""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs") class DeviceCopyAttrs(Attrs): """Attributes for tensor.device_copy""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.CastAttrs") class CastAttrs(Attrs): """Attributes for transform.cast""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ConcatenateAttrs") class ConcatenateAttrs(Attrs): """Attributes for tensor.concatenate""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TransposeAttrs") class TransposeAttrs(Attrs): """Attributes for transform.transpose""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ReshapeAttrs") class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes for transform.take""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.InitOpAttrs") class InitOpAttrs(Attrs): """Attributes for ops specifying a tensor""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ArangeAttrs") class ArangeAttrs(Attrs): """Attributes used in arange operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes used in stack operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes used in repeat operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes used in tile operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ReverseAttrs") class ReverseAttrs(Attrs): """Attributes used in reverse operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes used in squeeze operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes for transform.split""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes for transform.stranded_slice""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SliceLikeAttrs") class SliceLikeAttrs(Attrs): """Attributes for transform.slice_like""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ClipAttrs") class ClipAttrs(Attrs): """Attributes for transform.clip""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): """Attributes for tensor.shape_of""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MultiBoxPriorAttrs") class MultiBoxPriorAttrs(Attrs): """Attributes for vision.multibox_prior""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MultiBoxTransformLocAttrs") class MultiBoxTransformLocAttrs(Attrs): """Attributes for vision.multibox_transform_loc""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.GetValidCountsAttrs") class GetValidCountsAttrs(Attrs): """Attributes for vision.get_valid_counts""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.NonMaximumSuppressionAttrs") class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ROIAlignAttrs") class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ROIPoolAttrs") class ROIPoolAttrs(Attrs): """Attributes for vision.roi_pool""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.YoloReorgAttrs") class YoloReorgAttrs(Attrs): """Attributes for vision.yolo_reorg""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.ProposalAttrs") class ProposalAttrs(Attrs): """Attributes used in proposal operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool2DAttrs") class MaxPool2DAttrs(Attrs): """Attributes used in max_pool2d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool2DAttrs") class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool1DAttrs") class MaxPool1DAttrs(Attrs): """Attributes used in max_pool1d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool1DAttrs") class AvgPool1DAttrs(Attrs): """Attributes used in avg_pool1d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.MaxPool3DAttrs") class MaxPool3DAttrs(Attrs): """Attributes used in max_pool3d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.AvgPool3DAttrs") class AvgPool3DAttrs(Attrs): """Attributes used in avg_pool3d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BitPackAttrs") class BitPackAttrs(Attrs): """Attributes used in bitpack operator""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BinaryConv2DAttrs") class BinaryConv2DAttrs(Attrs): """Attributes used in bitserial conv2d operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.BinaryDenseAttrs") class BinaryDenseAttrs(Attrs): """Attributes used in bitserial dense operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" -@_register_relay_attr_node +@tvm._ffi.register_object("relay.attrs.SubPixelAttrs") class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" From a5cf36b7bfd6d1438314e5b735c8e914413f1df8 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 17 Mar 2020 02:25:10 +0000 Subject: [PATCH 4/4] remove direct access to analysis and transform namespace --- python/tvm/relay/__init__.py | 18 +++--------------- python/tvm/relay/analysis/__init__.py | 2 ++ tests/python/relay/test_call_graph.py | 18 +++++++++--------- ...mory_alloc.py => test_pass_memory_alloc.py} | 2 +- .../relay/test_pass_to_graph_normal_form.py | 4 ++-- 5 files changed, 17 insertions(+), 27 deletions(-) rename tests/python/relay/{test_memory_alloc.py => test_pass_memory_alloc.py} (98%) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f24fe44f9041..b1aac3e606a2 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -32,7 +32,7 @@ from . import transform from . import analysis -from .analysis import call_graph, feature, alpha_equal +from .analysis import alpha_equal from .build_module import build, create_executor, optimize from .transform import build_config from . import debug @@ -43,16 +43,13 @@ from .op import Op from .op import nn from .op import image -from .op import vision from .op import annotation +from .op import vision +from .op import contrib from .op.reduce import * from .op.tensor import * from .op.transform import * from .op.algorithm import * -from .op.nn import * -from .op.vision import * -from .op.contrib import * -from .op.image import * from . import frontend from . import backend from . import quantize @@ -60,9 +57,6 @@ # Dialects from . import qnn -# Load Memory pass -from .transform import memory_alloc - # Required to traverse large programs setrecursionlimit(10000) @@ -151,9 +145,3 @@ ModulePass = transform.ModulePass FunctionPass = transform.FunctionPass Sequential = transform.Sequential - -# Feature -Feature = feature.Feature - -# CallGraph -CallGraph = call_graph.CallGraph diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index eb8a91e084f3..957f5a3dcd94 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -24,3 +24,5 @@ # Feature from . import feature + +CallGraph = call_graph.CallGraph diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 92bb37367ef3..849f01546788 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -25,7 +25,7 @@ def test_callgraph_construct(): x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) mod["g1"] = relay.Function([x, y], x + y) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert "g1" in str(call_graph) assert relay.alpha_equal(mod, call_graph.module) @@ -38,7 +38,7 @@ def test_print_element(): x1 = relay.var("x1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3)) mod["g1"] = relay.Function([x1, y1], x1 - y1) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert "#refs = 0" in str(call_graph.print_var("g0")) assert "#refs = 0" in str(call_graph.print_var("g1")) @@ -54,13 +54,13 @@ def test_global_call_count(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], g0(x1, y1)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.global_call_count(g0) == 0 assert call_graph.global_call_count(g1) == 1 @@ -77,13 +77,13 @@ def test_ref_count(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], x1 - y1) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.ref_count(g0) == 1 assert call_graph.ref_count(g1) == 1 @@ -100,13 +100,13 @@ def test_nested_ref(): y1 = relay.var("y1", shape=(2, 3)) g1 = relay.GlobalVar("g1") mod[g1] = relay.Function([x1, y1], g0(x1, y1)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) p0 = relay.var("p0", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) mod["main"] = func - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.ref_count(g0) == 2 assert call_graph.ref_count(g1) == 1 @@ -138,7 +138,7 @@ def test_recursive_func(): mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) - call_graph = relay.CallGraph(mod) + call_graph = relay.analysis.CallGraph(mod) assert call_graph.is_recursive(sum_up) assert call_graph.ref_count(sum_up) == 2 diff --git a/tests/python/relay/test_memory_alloc.py b/tests/python/relay/test_pass_memory_alloc.py similarity index 98% rename from tests/python/relay/test_memory_alloc.py rename to tests/python/relay/test_pass_memory_alloc.py index 08fc39df9ad0..c3c53121d934 100644 --- a/tests/python/relay/test_memory_alloc.py +++ b/tests/python/relay/test_pass_memory_alloc.py @@ -18,7 +18,7 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay import memory_alloc +from tvm.relay.transform import memory_alloc def check_vm_alloc(func, check_fn): mod = tvm.IRModule() diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index dc47ad350fe5..94886220874d 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -16,9 +16,9 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay -from tvm.relay import op, create_executor, transform, Feature +from tvm.relay import op, create_executor, transform +from tvm.relay.analysis import Feature from tvm.relay.analysis import detect_feature