diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index ec6dc439ebc3..d7779b011c0a 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _): "tvm::relay::Span", "tvm::relay::TempExpr", "tvm::relay::TensorType", - "tvm::relay::TensorValue", "tvm::relay::Tuple", "tvm::relay::TupleGetItem", "tvm::relay::TupleType", - "tvm::relay::TupleValue", "tvm::relay::Type", "tvm::relay::TypeCall", "tvm::relay::TypeConstraint", diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 73868008a7e1..e090dc85f238 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -25,8 +25,8 @@ * Given a Relay module, and a Relay expression it produces a value. * * The interpreter's values are a naive representation of the values that - * can be produced by a Relay program and are exposed via tvm::Node's - * system to Python for introspection and debugging. + * can be produced by a Relay program and are exposed via TVM's object + * protocol to Python for introspection and debugging. * * The interpreter's intent is to serve as a reference semantics for the Relay IR, * as well as for debugging and testing. @@ -38,6 +38,8 @@ #include #include #include +#include +#include namespace tvm { namespace relay { @@ -64,11 +66,8 @@ namespace relay { runtime::TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target); -/*! \brief A Relay closure, i.e a scope and a function. */ -class Closure; - -/*! \brief The container type of Closures. */ -class ClosureNode : public Object { +/*! \brief The container type of Closures used by the interpreter. */ +class InterpreterClosureObj : public runtime::vm::ClosureObj { public: /*! \brief The set of free variables in the closure. * @@ -82,102 +81,69 @@ class ClosureNode : public Object { */ Function func; - ClosureNode() {} + InterpreterClosureObj() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("env", &env); v->Visit("func", &func); } - TVM_DLL static Closure make(tvm::Map env, Function func); - - static constexpr const char* _type_key = "relay.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object); + static constexpr const char* _type_key = "interpreter.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj); }; -class Closure : public ObjectRef { +class InterpreterClosure : public runtime::vm::Closure { public: - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode); + TVM_DLL InterpreterClosure(tvm::Map env, Function func); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, + InterpreterClosureObj); }; -/*! \brief A Relay Recursive Closure. A closure that has a name. */ -class RecClosure; - /*! \brief The container type of RecClosure. */ -class RecClosureNode : public Object { +class RecClosureObj : public Object { public: /*! \brief The closure. */ - Closure clos; + InterpreterClosure clos; /*! \brief variable the closure bind to. */ Var bind; - RecClosureNode() {} + RecClosureObj() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("clos", &clos); v->Visit("bind", &bind); } - TVM_DLL static RecClosure make(Closure clos, Var bind); - - static constexpr const char* _type_key = "relay.RecClosure"; - TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object); + static constexpr const char* _type_key = "interpreter.RecClosure"; + TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object); }; class RecClosure : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode); -}; - -/*! \brief A tuple value. */ -class TupleValue; - -/*! \brief Tuple (x, ... y). */ -struct TupleValueNode : Object { - tvm::Array fields; - - TupleValueNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } - - TVM_DLL static TupleValue make(tvm::Array value); - - static constexpr const char* _type_key = "relay.TupleValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object); -}; - -class TupleValue : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode); + TVM_DLL RecClosure(InterpreterClosure clos, Var bind); + TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj); }; -/*! \brief A reference value. */ -class RefValue; - -struct RefValueNode : Object { +struct RefValueObj : Object { mutable ObjectRef value; - RefValueNode() {} + RefValueObj() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } - TVM_DLL static RefValue make(ObjectRef val); - static constexpr const char* _type_key = "relay.RefValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); }; class RefValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode); + TVM_DLL RefValue(ObjectRef val); + TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj); }; -/*! \brief An ADT constructor value. */ -class ConstructorValue; - -struct ConstructorValueNode : Object { +struct ConstructorValueObj : Object { int32_t tag; tvm::Array fields; @@ -191,17 +157,17 @@ struct ConstructorValueNode : Object { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int32_t tag, - tvm::Array fields, - Constructor construtor = {}); - static constexpr const char* _type_key = "relay.ConstructorValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object); }; class ConstructorValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode); + TVM_DLL ConstructorValue(int32_t tag, + tvm::Array fields, + Constructor construtor = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; } // namespace relay diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 8ef9cb449d1b..1989afdf0787 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -50,10 +50,9 @@ namespace runtime { enum TypeIndex { /*! \brief Root object type. */ kRoot = 0, - kVMTensor = 1, - kVMClosure = 2, - kVMADT = 3, - kRuntimeModule = 4, + kClosure = 1, + kVMADT = 2, + kRuntimeModule = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 990ecf5ea733..43c222d0994a 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -25,36 +25,58 @@ #define TVM_RUNTIME_VM_H_ #include +#include #include #include #include #include #include +#include #include namespace tvm { namespace runtime { namespace vm { -/*! \brief An object representing a closure. */ +/*! + * \brief An object representing a closure. This object is used by both the + * Relay VM and interpreter. + */ class ClosureObj : public Object { public: - /*! \brief The index into the VM function table. */ + static constexpr const uint32_t _type_index = TypeIndex::kClosure; + static constexpr const char* _type_key = "Closure"; + TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. */ +class Closure : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + +/*! + * \brief An object representing a vm closure. + */ +class VMClosureObj : public ClosureObj { + public: + /*! + * \brief The index into the function list. The function could be any + * function object that is compatible to the VM runtime. + */ size_t func_index; /*! \brief The free variables of the closure. */ std::vector free_vars; - static constexpr const uint32_t _type_index = TypeIndex::kVMClosure; static constexpr const char* _type_key = "vm.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj); }; /*! \brief reference to closure. */ -class Closure : public ObjectRef { +class VMClosure : public Closure { public: - Closure(size_t func_index, std::vector free_vars); - - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); + VMClosure(size_t func_index, std::vector free_vars); + TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj); }; /*! \brief Magic number for NDArray list file */ diff --git a/python/tvm/container.py b/python/tvm/container.py index 274fc1f4027c..673afb428987 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -16,8 +16,10 @@ # under the License. """Container data structures used in TVM DSL.""" from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object +from tvm import ndarray as _nd from . import _api_internal +from ._ffi.object import Object, register_object, getitem_helper +from ._ffi.function import _init_api @register_object class Array(Object): @@ -114,3 +116,56 @@ class LoweredFunc(Object): MixedFunc = 0 HostFunc = 1 DeviceFunc = 2 + + +@register_object("vm.ADT") +class ADT(Object): + """Algebatic data type(ADT) object. + + Parameters + ---------- + tag : int + The tag of ADT. + + fields : list[Object] or tuple[Object] + The source tuple. + """ + def __init__(self, tag, fields): + for f in fields: + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \ + "tvm NDArray type, but received : {0}".format(type(f)) + self.__init_handle_by_constructor__(_ADT, tag, *fields) + + @property + def tag(self): + return _GetADTTag(self) + + def __getitem__(self, idx): + return getitem_helper( + self, _GetADTFields, len(self), idx) + + def __len__(self): + return _GetADTSize(self) + + +def tuple_object(fields=None): + """Create a ADT object from source tuple. + + Parameters + ---------- + fields : list[Object] or tuple[Object] + The source tuple. + + Returns + ------- + ret : ADT + The created object. + """ + fields = fields if fields else [] + for f in fields: + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \ + "NDArray type, but received : {0}".format(type(f)) + return _Tuple(*fields) + + +_init_api("tvm.container") diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c7cbcf096a6c..2432ec31cfe5 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -38,7 +38,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import vmobj # Root operators from .op import Op diff --git a/python/tvm/relay/backend/_vmobj.py b/python/tvm/relay/backend/_vmobj.py deleted file mode 100644 index 1e7efa467387..000000000000 --- a/python/tvm/relay/backend/_vmobj.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The VM Object FFI namespace.""" -from tvm._ffi.function import _init_api - -_init_api("_vmobj", __name__) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 59d9a8fae43c..f85f92fbdd59 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -20,6 +20,7 @@ import numpy as np +from tvm import container from . import _backend from .. import _make, analysis, transform from .. import module @@ -28,40 +29,6 @@ from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder -@register_relay_node -class TupleValue(Object): - """A tuple value produced by the interpreter.""" - def __init__(self, *fields): - self.__init_handle_by_constructor__( - _make.TupleValue, fields) - - def __getitem__(self, field_no): - return self.fields[field_no] - - def __len__(self): - return len(self.fields) - - def __str__(self): - body = ','.join(str(f) for f in self.fields) - return '({0})'.format(body) - - def __repr__(self): - body = ','.join(repr(f) for f in self.fields) - return '({0})'.format(body) - - def __iter__(self): - return iter(self.fields) - - -@register_relay_node -class Closure(Object): - """A closure produced by the interpreter.""" - - -@register_relay_node -class RecClosure(Object): - """A recursive closure produced by the interpreter.""" - @register_relay_node class ConstructorValue(Object): @@ -80,8 +47,8 @@ def __init__(self, value): def _arg_to_ast(mod, arg): if isinstance(arg, nd.NDArray): return Constant(arg.copyto(nd.cpu(0))) - elif isinstance(arg, TupleValue): - return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) + elif isinstance(arg, container.ADT): + return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, tuple): return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, RefValue): diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index aba55ef7d13e..31009008b23c 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -23,20 +23,18 @@ import numpy as np import tvm -from tvm import autotvm +from tvm import autotvm, container +from tvm.object import Object from tvm.relay import expr as _expr from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi import base as _base from . import _vm -from . import vmobj as _obj from .interpreter import Executor -ADT = _obj.ADT - def _convert(arg, cargs): if isinstance(arg, _expr.Constant): cargs.append(arg.data) - elif isinstance(arg, _obj.Object): + elif isinstance(arg, Object): cargs.append(arg) elif isinstance(arg, np.ndarray): nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0)) @@ -47,7 +45,7 @@ def _convert(arg, cargs): field_args = [] for field in arg: _convert(field, field_args) - cargs.append(_obj.tuple_object(field_args)) + cargs.append(container.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): dtype = "int32" if isinstance(arg, (int, bool)) else "float32" value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0)) diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py deleted file mode 100644 index 330257ff9467..000000000000 --- a/python/tvm/relay/backend/vmobj.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Runtime Object API.""" -from __future__ import absolute_import as _abs - -from tvm._ffi.object import Object, register_object, getitem_helper -from tvm import ndarray as _nd -from . import _vmobj - - -@register_object("vm.ADT") -class ADT(Object): - """Algebatic data type(ADT) object. - - Parameters - ---------- - tag : int - The tag of ADT. - - fields : list[Object] or tuple[Object] - The source tuple. - """ - def __init__(self, tag, fields): - for f in fields: - assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " - "tvm NDArray type, but received : {0}".format(type(f)) - self.__init_handle_by_constructor__( - _vmobj.ADT, tag, *fields) - - @property - def tag(self): - return _vmobj.GetADTTag(self) - - def __getitem__(self, idx): - return getitem_helper( - self, _vmobj.GetADTFields, len(self), idx) - - def __len__(self): - return _vmobj.GetADTNumberOfFields(self) - - -def tuple_object(fields): - """Create a ADT object from source tuple. - - Parameters - ---------- - fields : list[Object] or tuple[Object] - The source tuple. - - Returns - ------- - ret : ADT - The created object. - """ - for f in fields: - assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " - "NDArray type, but received : {0}".format(type(f)) - return _vmobj.Tuple(*fields) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 473b77ae7a7a..bc5c0e4222fb 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -17,9 +17,9 @@ """Common utilities""" from __future__ import absolute_import as _abs import logging +import numpy as np import tvm -import numpy as np from topi.util import get_const_tuple from .. import expr as _expr from .. import module as _module diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 25967b02df9b..bcf8985657da 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -17,6 +17,7 @@ #pylint: disable=invalid-name """Utilities for testing and benchmarks""" from __future__ import absolute_import as _abs +import numpy as np import tvm import tvm.relay as relay @@ -24,7 +25,6 @@ from tvm.relay import transform from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor from tvm.relay import TensorType, TupleType -import numpy as np from . import mlp from . import resnet diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 1edb27ae5eb3..72b835dddee7 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -32,18 +32,20 @@ # import numpy # import tvm # from tvm import relay +# from tvm import import container as _container # from tvm import nd -# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue +# from tvm.relay.backend.interpreter import RefValue, ConstructorValue PROLOGUE = [ ast.Import([alias('numpy', None)]), ast.Import([alias('tvm', None)]), ast.ImportFrom('tvm', [alias('relay', None)], 0), ast.ImportFrom('tvm', [alias('nd', None)], 0), + ast.ImportFrom('tvm', [alias('container', '_container')], + 0), ast.ImportFrom('tvm.relay.backend.interpreter', [alias('RefValue', None), - alias('TupleValue', None), alias('ConstructorValue', None)], - 0) + 0), ] class PythonConverter(ExprFunctor): @@ -253,7 +255,7 @@ def convert_input(py_input, arg_type): for i in range(len(arg_type.fields)): ret += convert_input( ast.Subscript( - ast.Attribute(py_input, 'fields', Load()), + py_input, ast.Index(Num(i)), Load()), arg_type.fields[i]) return ret @@ -282,7 +284,8 @@ def convert_output(ret_type): assignments += inner_assignments extra_args += inner_args fields.append(inner_output) - return (assignments, extra_args, self.create_call('TupleValue', fields)) + fields = [ast.List(fields, Load())] + return (assignments, extra_args, self.create_call('_container.tuple_object', fields)) # create a function to wrap the call of the lowered op and return # a call to that function @@ -444,7 +447,8 @@ def let_thunk(var): def visit_tuple(self, tup: Expr): fields, ret_defs = self.convert_fields(tup.fields) - return (self.create_call('TupleValue', fields), ret_defs) + fields = [ast.List(fields, Load())] + return (self.create_call('_container.tuple_object', fields), ret_defs) def visit_tuple_getitem(self, tgi: Expr): @@ -534,7 +538,7 @@ def visit_ref_write(self, write: Expr): thunk_name, [], ref_defs + val_defs + [ Assign([ast.Attribute(ref, 'value', Store())], val), - Return(self.create_call('TupleValue', [])) + Return(self.create_call('_container.tuple_object', [])) ]) return (self.create_call(thunk_name, []), [thunk]) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 7fdfdbb101af..eb0a7b75a44d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -22,6 +22,7 @@ * \brief An interpreter for the Relay IR. */ #include +#include #include #include #include @@ -36,100 +37,82 @@ namespace relay { using namespace runtime; -inline const PackedFunc& GetPackedFunc(const std::string& name) { - const PackedFunc* pf = tvm::runtime::Registry::Get(name); - CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; - return *pf; -} - -/* Object Implementation */ -Closure ClosureNode::make(tvm::Map env, Function func) { - ObjectPtr n = make_object(); +InterpreterClosure::InterpreterClosure(tvm::Map env, + Function func) { + ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); - return Closure(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.Closure") -.set_body_typed(ClosureNode::make); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; - }); +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; +}); +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} // TODO(@jroesch): this doesn't support mutual letrec /* Object Implementation */ -RecClosure RecClosureNode::make(Closure clos, Var bind) { - ObjectPtr n = make_object(); +RecClosure::RecClosure(InterpreterClosure clos, Var bind) { + ObjectPtr n = make_object(); n->clos = std::move(clos); n->bind = std::move(bind); - return RecClosure(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RecClosure") -.set_body_typed(RecClosureNode::make); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RecClosureNode(" << node->clos << ")"; +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RecClosureObj(" << node->clos << ")"; }); -TupleValue TupleValueNode::make(tvm::Array value) { - ObjectPtr n = make_object(); - n->fields = value; - return TupleValue(n); -} - -TVM_REGISTER_GLOBAL("relay._make.TupleValue") -.set_body_typed(TupleValueNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleValueNode(" << node->fields << ")"; - }); - - -RefValue RefValueNode::make(ObjectRef value) { - ObjectPtr n = make_object(); +RefValue::RefValue(ObjectRef value) { + ObjectPtr n = make_object(); n->value = value; - return RefValue(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed(RefValueNode::make); +.set_body_typed([](ObjectRef value){ + return RefValue(value); +}); -TVM_REGISTER_NODE_TYPE(RefValueNode); +TVM_REGISTER_NODE_TYPE(RefValueObj); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefValueNode(" << node->value << ")"; +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefValueObj(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int32_t tag, - tvm::Array fields, - Constructor constructor) { - ObjectPtr n = make_object(); +ConstructorValue::ConstructorValue(int32_t tag, + tvm::Array fields, + Constructor constructor) { + ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; n->constructor = constructor; - return ConstructorValue(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed(ConstructorValueNode::make); +.set_body_typed([](int32_t tag, tvm::Array fields, + Constructor constructor) { + return ConstructorValue(tag, fields, constructor); +}); -TVM_REGISTER_NODE_TYPE(ConstructorValueNode); +TVM_REGISTER_NODE_TYPE(ConstructorValueObj); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorValueNode(" << node->tag << "," +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; }); @@ -187,7 +170,7 @@ struct Stack { class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ -class InterpreterStateNode : public Object { +class InterpreterStateObj : public Object { public: using Frame = tvm::Map; using Stack = tvm::Array; @@ -206,16 +189,16 @@ class InterpreterStateNode : public Object { static InterpreterState make(Expr current_expr, Stack stack); static constexpr const char* _type_key = "relay.InterpreterState"; - TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); }; class InterpreterState : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); }; -InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { - ObjectPtr n = make_object(); +InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { + ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); return InterpreterState(n); @@ -292,7 +275,7 @@ class Interpreter : values.push_back(field_value); } - return TupleValueNode::make(values); + return ADT::Tuple(values); } ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { @@ -310,9 +293,9 @@ class Interpreter : } // We must use mutation here to build a self referential closure. - auto closure = ClosureNode::make(captured_mod, func); + InterpreterClosure closure(captured_mod, func); if (letrec_name.defined()) { - return RecClosureNode::make(closure, letrec_name); + return RecClosure(closure, letrec_name); } return std::move(closure); } @@ -374,16 +357,15 @@ class Interpreter : fset_input(arg_counter++, arg, true); } } else { - const TupleValueNode* tuple = arg.as(); - CHECK(tuple != nullptr); + const ADT adt = Downcast(arg); if (state & kNeedInputData) { - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i], false); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i], false); } } if (state & kNeedInputShape) { - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i], true); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i], true); } } } @@ -458,14 +440,14 @@ class Interpreter : } // Marshal the arguments. - // Handle tuple input/output by flattening them. + // Handle adt input/output by flattening them. size_t arg_len = 0; for (size_t i = 0; i < args.size(); ++i) { if (args[i]->IsInstance()) { ++arg_len; } else { - const auto* tvalue = args[i].as(); - arg_len += tvalue->fields.size(); + auto adt = Downcast(args[i]); + arg_len += adt.size(); } } size_t num_inputs = arg_len; @@ -495,10 +477,9 @@ class Interpreter : if (arg->IsInstance()) { fset_input(arg_counter++, arg); } else { - const TupleValueNode* tuple = arg.as(); - CHECK(tuple != nullptr); - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i]); + auto adt = Downcast(arg); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i]); } } } @@ -541,7 +522,7 @@ class Interpreter : TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { CHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); - Array fields; + std::vector fields; for (size_t i = 0; i < rtype->fields.size(); ++i) { if (is_dyn) { auto sh = out_shapes[i]; @@ -552,7 +533,7 @@ class Interpreter : } } packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return TupleValueNode::make(fields); + return ADT::Tuple(fields); } else { ObjectRef out_tensor; if (is_dyn) { @@ -569,7 +550,7 @@ class Interpreter : } // Invoke the closure - ObjectRef Invoke(const Closure& closure, + ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. @@ -594,7 +575,7 @@ class Interpreter : } if (bind.defined()) { - locals.Set(bind, RecClosureNode::make(closure, bind)); + locals.Set(bind, RecClosure(closure, bind)); } return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); @@ -616,14 +597,14 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueNode::make(con->tag, args, GetRef(con)); + return ConstructorValue(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. ObjectRef fn_val = Eval(call->op); - if (const ClosureNode* closure_node = fn_val.as()) { - auto closure = GetRef(closure_node); + if (const InterpreterClosureObj* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); return this->Invoke(closure, args); - } else if (const RecClosureNode* closure_node = fn_val.as()) { + } else if (const RecClosureObj* closure_node = fn_val.as()) { return this->Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " @@ -646,12 +627,13 @@ class Interpreter : ObjectRef VisitExpr_(const TupleGetItemNode* op) final { ObjectRef val = Eval(op->tuple); - auto product_node = val.as(); - CHECK(product_node) - << "interal error: when evaluating TupleGetItem expected a tuple value"; - CHECK_LT(static_cast(op->index), product_node->fields.size()) + const auto* adt_obj = val.as(); + CHECK(adt_obj) + << "interal error: when evaluating TupleGetItem expected an ADT value"; + auto adt = GetRef(adt_obj); + CHECK_LT(static_cast(op->index), adt.size()) << "internal error: index out of bounds"; - return product_node->fields[op->index]; + return adt[op->index]; } ObjectRef VisitExpr_(const IfNode* op) final { @@ -677,9 +659,9 @@ class Interpreter : ObjectRef VisitExpr_(const RefWriteNode* op) final { ObjectRef r = Eval(op->ref); - if (const RefValueNode* rv = r.as()) { + if (const RefValueObj* rv = r.as()) { rv->value = Eval(op->value); - return TupleValueNode::make({}); + return ADT::Tuple(std::vector()); } else { LOG(FATAL) << "type error, type system should have caught this"; return ObjectRef(); @@ -687,12 +669,12 @@ class Interpreter : } ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValueNode::make(Eval(op->value)); + return RefValue(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { ObjectRef r = Eval(op->ref); - if (const RefValueNode* rv = r.as()) { + if (const RefValueObj* rv = r.as()) { return rv->value; } else { LOG(FATAL) << "type error, type system should have caught this"; @@ -712,7 +694,7 @@ class Interpreter : } bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final { - const ConstructorValueNode* cvn = v.as(); + const ConstructorValueObj* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); CHECK_NE(cvn->tag, -1); @@ -729,11 +711,10 @@ class Interpreter : } bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final { - const TupleValueNode* tvn = v.as(); - CHECK(tvn) << "need to be a tuple for match"; - CHECK_EQ(op->patterns.size(), tvn->fields.size()); + auto adt = Downcast(v); + CHECK_EQ(op->patterns.size(), adt.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { - if (!VisitPattern(op->patterns[i], tvn->fields[i])) { + if (!VisitPattern(op->patterns[i], adt[i])) { return false; } } @@ -750,12 +731,12 @@ class Interpreter : } InterpreterState get_state(Expr e = Expr()) const { - InterpreterStateNode::Stack stack; + InterpreterStateObj::Stack stack; for (auto fr : this->stack_.frames) { - InterpreterStateNode::Frame frame = fr.locals; + InterpreterStateObj::Frame frame = fr.locals; stack.push_back(frame); } - auto state = InterpreterStateNode::make(e, stack); + auto state = InterpreterStateObj::make(e, stack); return state; } @@ -804,8 +785,5 @@ CreateInterpreter( TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") .set_body_typed(CreateInterpreter); -TVM_REGISTER_NODE_TYPE(ClosureNode); -TVM_REGISTER_NODE_TYPE(TupleValueNode); - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index af4f4390aa26..1e1d626a02ca 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -187,10 +188,11 @@ class ConstantFolder : public ExprMutator { << "invalid dimension after constant eval"; } return ConstantNode::make(nd_array); - } else if (const auto* val = value.as()) { + } else if (const auto* val = value.as()) { + runtime::ADT adt = GetRef(val); Array fields; - for (ObjectRef field : val->fields) { - fields.push_back(ObjectToExpr(field)); + for (size_t i = 0; i < adt.size(); ++i) { + fields.push_back(ObjectToExpr(adt[i])); } return TupleNode::make(fields); } else { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index c7935c49dfaf..e9e37d2e9102 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -935,11 +935,12 @@ class PartialEvaluator : public ExprFunctor if (v->IsInstance()) { auto nd_array = Downcast(v); return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array))); - } else if (const TupleValueNode* op = v.as()) { + } else if (const runtime::ADTObj* op = v.as()) { std::vector fields; tvm::Array fields_dyn; - for (const ObjectRef& field : op->fields) { - PStatic ps = Reify(field, ll); + auto adt = GetRef(op); + for (size_t i = 0; i < adt.size(); ++i) { + PStatic ps = Reify(adt[i], ll); fields.push_back(ps); fields_dyn.push_back(ps->dynamic); } diff --git a/src/runtime/vm/object.cc b/src/runtime/container.cc similarity index 71% rename from src/runtime/vm/object.cc rename to src/runtime/container.cc index b7174abc4ba8..cd426482d285 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/container.cc @@ -18,38 +18,28 @@ */ /*! - * \file src/runtime/vm/object.cc - * \brief VM related objects. + * \file src/runtime/container.cc + * \brief Implementations of common plain old data (POD) containers. */ -#include #include +#include #include #include -#include #include -#include -#include "../runtime_base.h" namespace tvm { namespace runtime { -namespace vm { - -Closure::Closure(size_t func_index, std::vector free_vars) { - auto ptr = make_object(); - ptr->func_index = func_index; - ptr->free_vars = std::move(free_vars); - data_ = std::move(ptr); -} +using namespace vm; -TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") +TVM_REGISTER_GLOBAL("container._GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") +TVM_REGISTER_GLOBAL("container._GetADTSize") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); @@ -57,7 +47,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") }); -TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") +TVM_REGISTER_GLOBAL("container._GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; @@ -66,7 +56,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("_vmobj.Tuple") +TVM_REGISTER_GLOBAL("container._Tuple") .set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { @@ -75,7 +65,7 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("_vmobj.ADT") +TVM_REGISTER_GLOBAL("container._ADT") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); @@ -88,15 +78,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT") TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); -} // namespace vm + } // namespace runtime } // namespace tvm - -using namespace tvm::runtime; - -int TVMGetObjectTag(TVMObjectHandle handle, int* tag) { - API_BEGIN(); - int res = static_cast(static_cast(handle)->type_index()); - *tag = res; - API_END(); -} diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c5ab1fdb4b62..84a3e26fb7f9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -45,6 +45,12 @@ namespace tvm { namespace runtime { namespace vm { +VMClosure::VMClosure(size_t func_index, std::vector free_vars) { + auto ptr = make_object(); + ptr->func_index = func_index; + ptr->free_vars = std::move(free_vars); + data_ = std::move(ptr); +} inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) { // We could put cache in here, from ctx to storage allocator. @@ -906,7 +912,7 @@ void VirtualMachine::RunLoop() { } case Opcode::InvokeClosure: { auto object = ReadRegister(instr.closure); - const auto* closure = object.as(); + const auto* closure = object.as(); std::vector args; for (auto free_var : closure->free_vars) { @@ -1008,7 +1014,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_freevar; i++) { free_vars.push_back(ReadRegister(instr.free_vars[i])); } - WriteRegister(instr.dst, Closure(instr.func_index, free_vars)); + WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars)); pc_++; goto main_loop; } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 19e5b1ff9c3c..8ff0d410184a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -62,16 +62,11 @@ def convert_to_list(x): def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.ADT): + elif isinstance(o, tvm.container.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) return result - elif isinstance(o, tvm.relay.backend.interpreter.TupleValue): - result = [] - for f in o.fields: - result.append(vmobj_to_list(f)) - return result elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): if o.constructor.name_hint == 'Cons': tl = vmobj_to_list(o.fields[1]) diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index cfb3fd42f834..3513832184ce 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -19,10 +19,9 @@ import tvm from tvm.contrib import graph_runtime -from tvm import relay +from tvm import relay, container from tvm.relay import testing from tvm.relay import vm -from tvm.relay import vmobj as _obj def benchmark_execution(mod, @@ -69,7 +68,7 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number, repeat=repeat) # Measure in millisecond. - prof_res = np.array(ftimer("main", _obj.Tensor(data)).results) * 1000 + prof_res = np.array(ftimer("main", data).results) * 1000 print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 8e304bd856e9..00c07b928b1a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -117,7 +117,7 @@ def tree_to_dict(t): def vmobj_to_list(o, dtype="float32"): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.ADT): + elif isinstance(o, tvm.container.ADT): if len(o) == 0: tensor_nil = p.get_var("tensor_nil", dtype=dtype) if tensor_nil.tag == o.tag: diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 85bba4402ea2..11a9e05b3e7f 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -18,8 +18,7 @@ import tvm import tvm.testing from tvm import nd -from tvm import relay -from tvm.relay.backend.interpreter import TupleValue +from tvm import relay, container from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor @@ -39,7 +38,8 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): def test_tuple_value(): - tv = TupleValue(relay.const(1), relay.const(2), relay.const(3)) + tv = container.tuple_object([relay.const(1), relay.const(2), + relay.const(3)]) np.testing.assert_allclose(tv[0].data.asnumpy(), 1) np.testing.assert_allclose(tv[1].data.asnumpy(), 2) np.testing.assert_allclose(tv[2].data.asnumpy(), 3) @@ -178,7 +178,7 @@ def test_function_taking_adt_ref_tuple(): ], prelude.cons) ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32'))) - tuple_value = TupleValue(*[ + tuple_value = container.tuple_object([ nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10) ]) @@ -202,8 +202,8 @@ def test_function_taking_adt_ref_tuple(): res_tuple = id_func(tuple_value) for i in range(10): - tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), - tuple_value.fields[i].asnumpy()) + tvm.testing.assert_allclose(res_tuple[i].asnumpy(), + tuple_value[i].asnumpy()) def test_tuple_passing(): x = relay.var('x', type_annotation=relay.ty.TupleType([ @@ -224,7 +224,8 @@ def test_tuple_passing(): out = f((10, 8)) tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) # Second use a tuple value. - value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12))) + value_tuple = container.tuple_object([nd.array(np.array(11)), + nd.array(np.array(12))]) out = f(value_tuple) tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index f87f90a85a0b..76aa697a2aab 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -19,7 +19,8 @@ from tvm import relay from tvm.relay.testing import to_python, run_as_python from tvm.relay.prelude import Prelude -from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue +from tvm.container import ADT +from tvm.relay.backend.interpreter import RefValue, ConstructorValue # helper: uses a dummy let binding to sequence a list # of expressions: expr1; expr2; expr3, etc. @@ -45,10 +46,10 @@ def assert_tensor_value(candidate, val): assert np.array_equal(candidate.asnumpy(), np.array(val)) -# assert that the candidate is a TupleValue with the indicate number of fields -def assert_tuple_value(candidate, fields): - assert isinstance(candidate, TupleValue) - assert len(candidate.fields) == fields +# assert that the candidate is an ADT with the indicated number of fields +def assert_adt_len(candidate, fields): + assert isinstance(candidate, ADT) + assert len(candidate) == fields # assert that the candidate is a ConstructorValue with the approrpaite constructor @@ -62,7 +63,7 @@ def assert_constructor_value(candidate, constructor, fields): def test_create_empty_tuple(): empty = relay.Tuple([]) tup_val = run_as_python(empty) - assert_tuple_value(tup_val, 0) + assert_adt_len(tup_val, 0) def test_create_scalar(): @@ -87,12 +88,12 @@ def test_create_nested_tuple(): ]) ]) tup_val = run_as_python(relay_tup) - assert_tuple_value(tup_val, 3) + assert_adt_len(tup_val, 3) for i in range(2): - assert_tensor_value(tup_val.fields[i], i + 1) - assert_tuple_value(tup_val.fields[2], 2) + assert_tensor_value(tup_val[i], i + 1) + assert_adt_len(tup_val[2], 2) for i in range(2): - assert_tensor_value(tup_val.fields[2].fields[i], i + 3) + assert_tensor_value(tup_val[2][i], i + 3) def test_tuple_get_item(): @@ -118,23 +119,23 @@ def test_create_let(): v = relay.Var('v') let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v])) tup_val = run_as_python(let) - assert_tuple_value(tup_val, 2) - assert_tuple_value(tup_val.fields[0], 0) - assert_tuple_value(tup_val.fields[1], 0) + assert_adt_len(tup_val, 2) + assert_adt_len(tup_val[0], 0) + assert_adt_len(tup_val[1], 0) def test_create_ref(): relay_ref = relay.RefCreate(relay.Tuple([])) ref_val = run_as_python(relay_ref) assert isinstance(ref_val, RefValue) - assert_tuple_value(ref_val.value, 0) + assert_adt_len(ref_val.value, 0) def test_ref_read(): v = relay.Var('v') assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v)) read_val = run_as_python(assign) - assert_tuple_value(read_val, 0) + assert_adt_len(read_val, 0) def test_ref_write(): @@ -143,7 +144,7 @@ def test_ref_write(): initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])), relay.RefWrite(v, relay.Tuple([relay.const(2)]))) write_val = run_as_python(initial_write) - assert_tuple_value(write_val, 0) + assert_adt_len(write_val, 0) # now ensure that the value, once written, can be read back # (we read the value before and after mutation) @@ -155,11 +156,11 @@ def test_ref_write(): seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])), relay.Tuple([relay.RefRead(w), relay.RefRead(v)])))) read_val = run_as_python(read_after_write) - assert_tuple_value(read_val, 2) - assert_tuple_value(read_val.fields[0], 1) - assert_tuple_value(read_val.fields[1], 1) - assert_tensor_value(read_val.fields[0].fields[0], 1) - assert_tensor_value(read_val.fields[1].fields[0], 2) + assert_adt_len(read_val, 2) + assert_adt_len(read_val[0], 1) + assert_adt_len(read_val[1], 1) + assert_tensor_value(read_val[0][0], 1) + assert_tensor_value(read_val[1][0], 2) def test_if(): @@ -191,7 +192,7 @@ def test_local_function(): call2 = relay.Let(f, ident, f(relay.const(2))) call_val1 = run_as_python(call1) - assert_tuple_value(call_val1, 0) + assert_adt_len(call_val1, 0) call_val2 = run_as_python(call2) assert_tensor_value(call_val2, 2) @@ -211,9 +212,9 @@ def test_global_function(): assert_tensor_value(call_val1, 1) call_val2 = run_as_python(call2, mod) - assert_tuple_value(call_val2, 2) - assert_tensor_value(call_val2.fields[0], 2) - assert_tensor_value(call_val2.fields[1], 2) + assert_adt_len(call_val2, 2) + assert_tensor_value(call_val2[0], 2) + assert_tensor_value(call_val2[1], 2) def test_constructor(): @@ -230,7 +231,7 @@ def test_constructor(): box_val_tup = run_as_python(init_box_tup, mod) assert_constructor_value(box_val_tup, box_ctor, 1) - assert_tuple_value(box_val_tup.fields[0], 0) + assert_adt_len(box_val_tup.fields[0], 0) def test_match_wildcard(): @@ -341,7 +342,7 @@ def test_local_recursion(): assert_tensor_value(val.fields[1].fields[0], 2) assert_constructor_value(val.fields[1].fields[1], p.cons, 2) assert_tensor_value(val.fields[1].fields[1].fields[0], 3) - assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) + assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) def test_global_recursion(): @@ -372,7 +373,7 @@ def test_global_recursion(): call2 = copy_def(p.cons(relay.Tuple([]), p.nil())) val2 = run_as_python(call2, mod) assert_constructor_value(val2, p.cons, 2) - assert_tuple_value(val2.fields[0], 0) + assert_adt_len(val2.fields[0], 0) assert_constructor_value(val2.fields[1], p.nil, 0) @@ -437,10 +438,10 @@ def test_arbitrary_let_nesting(): ]) tup_val = run_as_python(expr, mod) - assert_tuple_value(tup_val, 3) - assert_tensor_value(tup_val.fields[0], 2) - assert_tensor_value(tup_val.fields[1], 3) - assert_tensor_value(tup_val.fields[2], 4) + assert_adt_len(tup_val, 3) + assert_tensor_value(tup_val[0], 2) + assert_tensor_value(tup_val[1], 3) + assert_tensor_value(tup_val[2], 4) def test_ref_execution_order(): @@ -475,12 +476,12 @@ def test_ref_execution_order(): ]))) tup_val = run_as_python(expr) - assert_tuple_value(tup_val, 5) - assert_tensor_value(tup_val.fields[0], 1) - assert_tensor_value(tup_val.fields[1], 2) - assert_tensor_value(tup_val.fields[2], 3) - assert_tensor_value(tup_val.fields[3], 4) - assert_tensor_value(tup_val.fields[4], 5) + assert_adt_len(tup_val, 5) + assert_tensor_value(tup_val[0], 1) + assert_tensor_value(tup_val[1], 2) + assert_tensor_value(tup_val[2], 3) + assert_tensor_value(tup_val[3], 4) + assert_tensor_value(tup_val[4], 5) def test_op_add(): @@ -501,6 +502,7 @@ def verify_stack(dshapes, axis): args.append(relay.const(data)) call = relay.stack(relay.Tuple(args), axis) call_val = run_as_python(call) + type(call_val) assert_tensor_value(call_val, ref_res) verify_stack([(2,), (2,), (2,)], -1) @@ -517,9 +519,9 @@ def verify_split(shape, indices_or_sections, axis=0): ref_res = np.split(x, indices_or_sections, axis=axis) call = relay.split(relay.const(x), indices_or_sections, axis=axis) call_val = run_as_python(call) - assert_tuple_value(call_val, len(ref_res)) + assert_adt_len(call_val, len(ref_res)) for i in range(len(ref_res)): - assert_tensor_value(call_val.fields[i], ref_res[i]) + assert_tensor_value(call_val[i], ref_res[i]) verify_split((2, 3), 2) verify_split((5, 3), [3]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d53360d76656..d4a7a1a25689 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vm.ADT): + elif isinstance(o, tvm.container.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py deleted file mode 100644 index 82a2b116d82c..000000000000 --- a/tests/python/relay/test_vm_object.py +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np -import tvm -from tvm.relay import vm - -def test_adt(): - arr = tvm.nd.array([1,2,3]) - y = vm.ADT(0, [arr, arr]) - - assert len(y) == 2 - assert isinstance(y, vm.ADT) - y[0:1][-1] == arr - assert y.tag == 0 - assert isinstance(arr, tvm.nd.NDArray) - - -if __name__ == "__main__": - test_adt() diff --git a/tests/python/unittest/test_container.py b/tests/python/unittest/test_container.py new file mode 100644 index 000000000000..7bdab82d7a65 --- /dev/null +++ b/tests/python/unittest/test_container.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +from tvm import nd, relay +from tvm import container as _container + + +def test_adt_constructor(): + arr = nd.array([1, 2, 3]) + fields = [arr, arr] + y = _container.ADT(0, [arr, arr]) + + assert len(y) == 2 + assert isinstance(y, _container.ADT) + y[0:1][-1] == arr + assert y.tag == 0 + assert isinstance(arr, nd.NDArray) + + +def test_tuple_object(): + x = relay.var( + 'x', + type_annotation=relay.ty.TupleType([ + relay.ty.TensorType((), 'int32'), + relay.ty.TensorType((), 'int32') + ])) + + fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) + mod = relay.Module.from_expr(fn) + + exe = relay.create_executor( + kind="vm", mod=mod, ctx=nd.cpu(), target="llvm") + f = exe.evaluate() + value_tuple = _container.tuple_object( + [nd.array(np.array(11)), + nd.array(np.array(12))]) + # pass an ADT object to evaluate + out = f(value_tuple) + tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) + + +if __name__ == "__main__": + test_adt_constructor() + test_tuple_object()