diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index ec6dc439ebc37..d7779b011c0ad 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _): "tvm::relay::Span", "tvm::relay::TempExpr", "tvm::relay::TensorType", - "tvm::relay::TensorValue", "tvm::relay::Tuple", "tvm::relay::TupleGetItem", "tvm::relay::TupleType", - "tvm::relay::TupleValue", "tvm::relay::Type", "tvm::relay::TypeCall", "tvm::relay::TypeConstraint", diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index dc35fc26486a8..51bdfc253f522 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -25,8 +25,8 @@ * Given a Relay module, and a Relay expression it produces a value. * * The interpreter's values are a naive representation of the values that - * can be produced by a Relay program and are exposed via tvm::Node's - * system to Python for introspection and debugging. + * can be produced by a Relay program and are exposed via tvm's object + * protocol to Python for introspection and debugging. * * The interpreter's intent is to serve as a reference semantics for the Relay IR, * as well as for debugging and testing. @@ -38,6 +38,8 @@ #include #include #include +#include +#include namespace tvm { namespace relay { @@ -64,100 +66,42 @@ namespace relay { runtime::TypedPackedFunc CreateInterpreter(Module mod, DLContext context, Target target); -/*! \brief A Relay closure, i.e a scope and a function. */ -class Closure; - -/*! \brief The container type of Closures. */ -class ClosureNode : public Object { - public: - /*! \brief The set of free variables in the closure. - * - * These are the captured variables which are required for - * evaluation when we call the closure. - */ - tvm::Map env; - /*! \brief The function which implements the closure. - * - * \note May reference the variables contained in the env. 
- */ - Function func; - - ClosureNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("env", &env); - v->Visit("func", &func); - } - - TVM_DLL static Closure make(tvm::Map env, Function func); - - static constexpr const char* _type_key = "relay.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object); -}; - -class Closure : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode); -}; - /*! \brief A Relay Recursive Closure. A closure that has a name. */ class RecClosure; /*! \brief The container type of RecClosure. */ -class RecClosureNode : public Object { +class RecClosureObj : public Object { public: /*! \brief The closure. */ - Closure clos; + runtime::Closure clos; /*! \brief variable the closure bind to. */ Var bind; - RecClosureNode() {} + RecClosureObj() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("clos", &clos); v->Visit("bind", &bind); } - TVM_DLL static RecClosure make(Closure clos, Var bind); + TVM_DLL static RecClosure make(runtime::Closure clos, Var bind); static constexpr const char* _type_key = "relay.RecClosure"; - TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object); }; class RecClosure : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode); -}; - -/*! \brief A tuple value. */ -class TupleValue; - -/*! \brief Tuple (x, ... y). */ -struct TupleValueNode : Object { - tvm::Array fields; - - TupleValueNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } - - TVM_DLL static TupleValue make(tvm::Array value); - - static constexpr const char* _type_key = "relay.TupleValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object); -}; - -class TupleValue : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj); }; /*! \brief A reference value. 
*/ class RefValue; -struct RefValueNode : Object { +struct RefValueObj : Object { mutable ObjectRef value; - RefValueNode() {} + RefValueObj() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); @@ -166,18 +110,18 @@ struct RefValueNode : Object { TVM_DLL static RefValue make(ObjectRef val); static constexpr const char* _type_key = "relay.RefValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); }; class RefValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj); }; /*! \brief An ADT constructor value. */ class ConstructorValue; -struct ConstructorValueNode : Object { +struct ConstructorValueObj : Object { int32_t tag; tvm::Array fields; @@ -196,12 +140,12 @@ struct ConstructorValueNode : Object { Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object); }; class ConstructorValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; } // namespace relay diff --git a/include/tvm/runtime/common_object.h b/include/tvm/runtime/common_object.h new file mode 100644 index 0000000000000..632101088d22d --- /dev/null +++ b/include/tvm/runtime/common_object.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/common_object.h + * \brief The objects that are commonly used by different runtime, i.e. Relay VM + * and interpreter. + */ +#ifndef TVM_RUNTIME_COMMON_OBJECT_H_ +#define TVM_RUNTIME_COMMON_OBJECT_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief An object representing a closure. This object is used by both the + * Relay VM and interpreter. + */ +class ClosureObj : public Object { + public: + /*! + * \brief The index into the function list. The function could be any + * function object that is compatible to a certain runtime, i.e. VM or + * interpreter. + */ + size_t func_index; + /*! \brief The free variables of the closure. */ + std::vector free_vars; + + static constexpr const uint32_t _type_index = TypeIndex::kClosure; + static constexpr const char* _type_key = "Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. 
*/ +class Closure : public ObjectRef { + public: + Closure(size_t func_index, std::vector free_vars) { + auto ptr = make_object(); + ptr->func_index = func_index; + ptr->free_vars = std::move(free_vars); + data_ = std::move(ptr); + } + + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_COMMON_OBJECT_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 92d3e7149463f..29125b4a25a0c 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -169,8 +169,8 @@ class ADTObj : public Object, public InplaceArrayBase { uint32_t size; // The fields of the structure follows directly in memory. - static constexpr const uint32_t _type_index = TypeIndex::kVMADT; - static constexpr const char* _type_key = "vm.ADT"; + static constexpr const uint32_t _type_index = TypeIndex::kADT; + static constexpr const char* _type_key = "ADT"; TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); private: diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index a2e9188fcd2b7..8e7727894af1c 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -51,10 +51,9 @@ namespace runtime { enum TypeIndex { /*! \brief Root object type. */ kRoot = 0, - kVMTensor = 1, - kVMClosure = 2, - kVMADT = 3, - kRuntimeModule = 4, + kClosure = 1, + kADT = 2, + kRuntimeModule = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 990ecf5ea7336..dce65c2c184c7 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_VM_H_ #include +#include #include #include #include @@ -36,27 +37,6 @@ namespace tvm { namespace runtime { namespace vm { -/*! \brief An object representing a closure. */ -class ClosureObj : public Object { - public: - /*! 
\brief The index into the VM function table. */ - size_t func_index; - /*! \brief The free variables of the closure. */ - std::vector free_vars; - - static constexpr const uint32_t _type_index = TypeIndex::kVMClosure; - static constexpr const char* _type_key = "vm.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); -}; - -/*! \brief reference to closure. */ -class Closure : public ObjectRef { - public: - Closure(size_t func_index, std::vector free_vars); - - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); -}; - /*! \brief Magic number for NDArray list file */ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; diff --git a/python/tvm/container.py b/python/tvm/container.py index 274fc1f4027ce..70aa0c6ef6c3a 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -16,8 +16,10 @@ # under the License. """Container data structures used in TVM DSL.""" from __future__ import absolute_import as _abs -from ._ffi.object import Object, register_object +from tvm import ndarray as _nd from . import _api_internal +from ._ffi.object import Object, register_object, getitem_helper +from ._ffi.function import _init_api @register_object class Array(Object): @@ -114,3 +116,56 @@ class LoweredFunc(Object): MixedFunc = 0 HostFunc = 1 DeviceFunc = 2 + + +@register_object +class ADT(Object): + """Algebatic data type(ADT) object. + + Parameters + ---------- + tag : int + The tag of ADT. + + fields : list[Object] or tuple[Object] + The source tuple. 
+ """ + def __init__(self, tag, fields): + for f in fields: + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \ + "tvm NDArray type, but received : {0}".format(type(f)) + self.__init_handle_by_constructor__(_ADT, tag, *fields) + + @property + def tag(self): + return _GetADTTag(self) + + def __getitem__(self, idx): + return getitem_helper( + self, _GetADTFields, len(self), idx) + + def __len__(self): + return _GetADTSize(self) + + +def tuple_object(fields=None): + """Create a ADT object from source tuple. + + Parameters + ---------- + fields : list[Object] or tuple[Object] + The source tuple. + + Returns + ------- + ret : ADT + The created object. + """ + fields = fields if fields else [] + for f in fields: + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \ + "NDArray type, but received : {0}".format(type(f)) + return _Tuple(*fields) + + +_init_api("tvm.container") diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c7cbcf096a6cf..2432ec31cfe57 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -38,7 +38,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import vmobj # Root operators from .op import Op diff --git a/python/tvm/relay/backend/_vmobj.py b/python/tvm/relay/backend/_vmobj.py deleted file mode 100644 index 1e7efa4673873..0000000000000 --- a/python/tvm/relay/backend/_vmobj.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The VM Object FFI namespace.""" -from tvm._ffi.function import _init_api - -_init_api("_vmobj", __name__) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 59d9a8fae43c0..f85f92fbdd593 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -20,6 +20,7 @@ import numpy as np +from tvm import container from . import _backend from .. import _make, analysis, transform from .. import module @@ -28,40 +29,6 @@ from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder -@register_relay_node -class TupleValue(Object): - """A tuple value produced by the interpreter.""" - def __init__(self, *fields): - self.__init_handle_by_constructor__( - _make.TupleValue, fields) - - def __getitem__(self, field_no): - return self.fields[field_no] - - def __len__(self): - return len(self.fields) - - def __str__(self): - body = ','.join(str(f) for f in self.fields) - return '({0})'.format(body) - - def __repr__(self): - body = ','.join(repr(f) for f in self.fields) - return '({0})'.format(body) - - def __iter__(self): - return iter(self.fields) - - -@register_relay_node -class Closure(Object): - """A closure produced by the interpreter.""" - - -@register_relay_node -class RecClosure(Object): - """A recursive closure produced by the interpreter.""" - @register_relay_node class ConstructorValue(Object): @@ -80,8 +47,8 @@ def __init__(self, value): def _arg_to_ast(mod, arg): if isinstance(arg, nd.NDArray): return 
Constant(arg.copyto(nd.cpu(0))) - elif isinstance(arg, TupleValue): - return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) + elif isinstance(arg, container.ADT): + return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, tuple): return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, RefValue): diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index aba55ef7d13e6..6279b4146ce79 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -23,20 +23,18 @@ import numpy as np import tvm -from tvm import autotvm +from tvm import autotvm, container from tvm.relay import expr as _expr from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi import base as _base +from tvm._ffi.object import Object from . import _vm -from . import vmobj as _obj from .interpreter import Executor -ADT = _obj.ADT - def _convert(arg, cargs): if isinstance(arg, _expr.Constant): cargs.append(arg.data) - elif isinstance(arg, _obj.Object): + elif isinstance(arg, Object): cargs.append(arg) elif isinstance(arg, np.ndarray): nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0)) @@ -47,7 +45,7 @@ def _convert(arg, cargs): field_args = [] for field in arg: _convert(field, field_args) - cargs.append(_obj.tuple_object(field_args)) + cargs.append(container.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): dtype = "int32" if isinstance(arg, (int, bool)) else "float32" value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0)) diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py deleted file mode 100644 index 330257ff94674..0000000000000 --- a/python/tvm/relay/backend/vmobj.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Runtime Object API.""" -from __future__ import absolute_import as _abs - -from tvm._ffi.object import Object, register_object, getitem_helper -from tvm import ndarray as _nd -from . import _vmobj - - -@register_object("vm.ADT") -class ADT(Object): - """Algebatic data type(ADT) object. - - Parameters - ---------- - tag : int - The tag of ADT. - - fields : list[Object] or tuple[Object] - The source tuple. - """ - def __init__(self, tag, fields): - for f in fields: - assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " - "tvm NDArray type, but received : {0}".format(type(f)) - self.__init_handle_by_constructor__( - _vmobj.ADT, tag, *fields) - - @property - def tag(self): - return _vmobj.GetADTTag(self) - - def __getitem__(self, idx): - return getitem_helper( - self, _vmobj.GetADTFields, len(self), idx) - - def __len__(self): - return _vmobj.GetADTNumberOfFields(self) - - -def tuple_object(fields): - """Create a ADT object from source tuple. - - Parameters - ---------- - fields : list[Object] or tuple[Object] - The source tuple. - - Returns - ------- - ret : ADT - The created object. 
- """ - for f in fields: - assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " - "NDArray type, but received : {0}".format(type(f)) - return _vmobj.Tuple(*fields) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 473b77ae7a7aa..bc5c0e4222fb4 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -17,9 +17,9 @@ """Common utilities""" from __future__ import absolute_import as _abs import logging +import numpy as np import tvm -import numpy as np from topi.util import get_const_tuple from .. import expr as _expr from .. import module as _module diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 25967b02df9b8..bcf8985657dae 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -17,6 +17,7 @@ #pylint: disable=invalid-name """Utilities for testing and benchmarks""" from __future__ import absolute_import as _abs +import numpy as np import tvm import tvm.relay as relay @@ -24,7 +25,6 @@ from tvm.relay import transform from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor from tvm.relay import TensorType, TupleType -import numpy as np from . import mlp from . 
import resnet diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 1edb27ae5eb32..72b835dddee76 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -32,18 +32,20 @@ # import numpy # import tvm # from tvm import relay +# from tvm import import container as _container # from tvm import nd -# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue +# from tvm.relay.backend.interpreter import RefValue, ConstructorValue PROLOGUE = [ ast.Import([alias('numpy', None)]), ast.Import([alias('tvm', None)]), ast.ImportFrom('tvm', [alias('relay', None)], 0), ast.ImportFrom('tvm', [alias('nd', None)], 0), + ast.ImportFrom('tvm', [alias('container', '_container')], + 0), ast.ImportFrom('tvm.relay.backend.interpreter', [alias('RefValue', None), - alias('TupleValue', None), alias('ConstructorValue', None)], - 0) + 0), ] class PythonConverter(ExprFunctor): @@ -253,7 +255,7 @@ def convert_input(py_input, arg_type): for i in range(len(arg_type.fields)): ret += convert_input( ast.Subscript( - ast.Attribute(py_input, 'fields', Load()), + py_input, ast.Index(Num(i)), Load()), arg_type.fields[i]) return ret @@ -282,7 +284,8 @@ def convert_output(ret_type): assignments += inner_assignments extra_args += inner_args fields.append(inner_output) - return (assignments, extra_args, self.create_call('TupleValue', fields)) + fields = [ast.List(fields, Load())] + return (assignments, extra_args, self.create_call('_container.tuple_object', fields)) # create a function to wrap the call of the lowered op and return # a call to that function @@ -444,7 +447,8 @@ def let_thunk(var): def visit_tuple(self, tup: Expr): fields, ret_defs = self.convert_fields(tup.fields) - return (self.create_call('TupleValue', fields), ret_defs) + fields = [ast.List(fields, Load())] + return (self.create_call('_container.tuple_object', fields), ret_defs) def visit_tuple_getitem(self, tgi: Expr): @@ 
-534,7 +538,7 @@ def visit_ref_write(self, write: Expr): thunk_name, [], ref_defs + val_defs + [ Assign([ast.Attribute(ref, 'value', Store())], val), - Return(self.create_call('TupleValue', [])) + Return(self.create_call('_container.tuple_object', [])) ]) return (self.create_call(thunk_name, []), [thunk]) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 432ad29b13ce6..1830ff76dbe6f 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -43,79 +44,42 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { return *pf; } -/* Object Implementation */ -Closure ClosureNode::make(tvm::Map env, Function func) { - ObjectPtr n = make_object(); - n->env = std::move(env); - n->func = std::move(func); - return Closure(n); -} - -TVM_REGISTER_GLOBAL("relay._make.Closure") -.set_body_typed(ClosureNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; - }); - - // TODO(@jroesch): this doesn't support mutual letrec /* Object Implementation */ -RecClosure RecClosureNode::make(Closure clos, Var bind) { - ObjectPtr n = make_object(); +RecClosure RecClosureObj::make(Closure clos, Var bind) { + ObjectPtr n = make_object(); n->clos = std::move(clos); n->bind = std::move(bind); return RecClosure(n); } -TVM_REGISTER_GLOBAL("relay._make.RecClosure") -.set_body_typed(RecClosureNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RecClosureNode(" << node->clos << ")"; - }); - -TupleValue TupleValueNode::make(tvm::Array value) { - ObjectPtr n = make_object(); - n->fields = value; - return TupleValue(n); -} - 
-TVM_REGISTER_GLOBAL("relay._make.TupleValue") -.set_body_typed(TupleValueNode::make); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleValueNode(" << node->fields << ")"; +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RecClosureObj(" << node->clos << ")"; }); - -RefValue RefValueNode::make(ObjectRef value) { - ObjectPtr n = make_object(); +RefValue RefValueObj::make(ObjectRef value) { + ObjectPtr n = make_object(); n->value = value; return RefValue(n); } TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed(RefValueNode::make); +.set_body_typed(RefValueObj::make); -TVM_REGISTER_NODE_TYPE(RefValueNode); +TVM_REGISTER_NODE_TYPE(RefValueObj); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefValueNode(" << node->value << ")"; +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefValueObj(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int32_t tag, +ConstructorValue ConstructorValueObj::make(int32_t tag, tvm::Array fields, Constructor constructor) { - ObjectPtr n = make_object(); + ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; n->constructor = constructor; @@ -123,14 +87,14 @@ ConstructorValue ConstructorValueNode::make(int32_t tag, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed(ConstructorValueNode::make); +.set_body_typed(ConstructorValueObj::make); -TVM_REGISTER_NODE_TYPE(ConstructorValueNode); +TVM_REGISTER_NODE_TYPE(ConstructorValueObj); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorValueNode(" << node->tag << "," 
+.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; }); @@ -293,12 +257,17 @@ class Interpreter : values.push_back(field_value); } - return TupleValueNode::make(values); + return ADT::Tuple(values); } ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { - tvm::Map captured_mod; + if (func_index_map_.count(func) == 0) { + func_index_map_[func] = func_index_++; + eval_funcs_.push_back(func); + } + std::vector free_var_values; Array free_vars = FreeVars(func); + std::vector captured_vars; for (const auto& var : free_vars) { // Evaluate the free var (which could be a function call) if it hasn't @@ -307,13 +276,16 @@ class Interpreter : continue; } - captured_mod.Set(var, Eval(var)); + ObjectRef value = Eval(var); + free_var_values.push_back(value); + captured_vars.push_back(var); } // We must use mutation here to build a self referential closure. - auto closure = ClosureNode::make(captured_mod, func); + Closure closure(func_index_map_[func], free_var_values); + closure_captured_vars_[closure] = captured_vars; if (letrec_name.defined()) { - return RecClosureNode::make(closure, letrec_name); + return RecClosureObj::make(closure, letrec_name); } return std::move(closure); } @@ -375,16 +347,15 @@ class Interpreter : fset_input(arg_counter++, arg, true); } } else { - const TupleValueNode* tuple = arg.as(); - CHECK(tuple != nullptr); + const ADT adt = Downcast(arg); if (state & kNeedInputData) { - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i], false); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i], false); } } if (state & kNeedInputShape) { - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i], true); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i], true); } } } @@ -458,14 
+429,14 @@ class Interpreter : } // Marshal the arguments. - // Handle tuple input/output by flattening them. + // Handle adt input/output by flattening them. size_t arg_len = 0; for (size_t i = 0; i < args.size(); ++i) { if (args[i]->IsInstance()) { ++arg_len; } else { - const auto* tvalue = args[i].as(); - arg_len += tvalue->fields.size(); + auto adt = Downcast(args[i]); + arg_len += adt.size(); } } size_t num_inputs = arg_len; @@ -495,10 +466,9 @@ class Interpreter : if (arg->IsInstance()) { fset_input(arg_counter++, arg); } else { - const TupleValueNode* tuple = arg.as(); - CHECK(tuple != nullptr); - for (size_t i = 0; i < tuple->fields.size(); ++i) { - fset_input(arg_counter++, tuple->fields[i]); + auto adt = Downcast(arg); + for (size_t i = 0; i < adt.size(); ++i) { + fset_input(arg_counter++, adt[i]); } } } @@ -541,7 +511,7 @@ class Interpreter : TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { CHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); - Array fields; + std::vector fields; for (size_t i = 0; i < rtype->fields.size(); ++i) { if (is_dyn) { auto sh = out_shapes[i]; @@ -552,7 +522,7 @@ class Interpreter : } } packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return TupleValueNode::make(fields); + return ADT::Tuple(fields); } else { ObjectRef out_tensor; if (is_dyn) { @@ -572,15 +542,20 @@ class Interpreter : ObjectRef Invoke(const Closure& closure, const tvm::Array& args, const Var& bind = Var()) { + CHECK_GT(eval_funcs_.size(), closure->func_index); + CHECK_GT(func_index_map_.count(eval_funcs_[closure->func_index]), 0U); + auto func = eval_funcs_[closure->func_index]; // Get a reference to the function inside the closure. - if (closure->func->IsPrimitive()) { - return InvokePrimitiveOp(closure->func, args); + if (func->IsPrimitive()) { + return InvokePrimitiveOp(func, args); } - auto func = closure->func; // Allocate a frame with the parameters and free variables. 
tvm::Map locals; CHECK_EQ(func->params.size(), args.size()); + CHECK_GT(closure_captured_vars_.count(closure), 0U); + const auto& captured_vars = closure_captured_vars_[closure]; + CHECK_EQ(captured_vars.size(), closure->free_vars.size()); for (size_t i = 0; i < func->params.size(); i++) { CHECK_EQ(locals.count(func->params[i]), 0); @@ -588,13 +563,14 @@ class Interpreter : } // Add the var to value mappings from the Closure's environment. - for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { - CHECK_EQ(locals.count((*it).first), 0); - locals.Set((*it).first, (*it).second); + for (size_t i = 0; i < closure->free_vars.size(); i++) { + Var var = captured_vars[i]; + CHECK_EQ(locals.count(var), 0); + locals.Set(var, closure->free_vars[i]); } if (bind.defined()) { - locals.Set(bind, RecClosureNode::make(closure, bind)); + locals.Set(bind, RecClosureObj::make(closure, bind)); } return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); @@ -616,14 +592,14 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueNode::make(con->tag, args, GetRef(con)); + return ConstructorValueObj::make(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. 
ObjectRef fn_val = Eval(call->op); - if (const ClosureNode* closure_node = fn_val.as()) { + if (const ClosureObj* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return this->Invoke(closure, args); - } else if (const RecClosureNode* closure_node = fn_val.as()) { + } else if (const RecClosureObj* closure_node = fn_val.as()) { return this->Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " @@ -646,12 +622,13 @@ class Interpreter : ObjectRef VisitExpr_(const TupleGetItemNode* op) final { ObjectRef val = Eval(op->tuple); - auto product_node = val.as(); - CHECK(product_node) - << "interal error: when evaluating TupleGetItem expected a tuple value"; - CHECK_LT(static_cast(op->index), product_node->fields.size()) + const auto* adt_obj = val.as(); + CHECK(adt_obj) + << "interal error: when evaluating TupleGetItem expected an ADT value"; + auto adt = GetRef(adt_obj); + CHECK_LT(static_cast(op->index), adt.size()) << "internal error: index out of bounds"; - return product_node->fields[op->index]; + return adt[op->index]; } ObjectRef VisitExpr_(const IfNode* op) final { @@ -677,9 +654,9 @@ class Interpreter : ObjectRef VisitExpr_(const RefWriteNode* op) final { ObjectRef r = Eval(op->ref); - if (const RefValueNode* rv = r.as()) { + if (const RefValueObj* rv = r.as()) { rv->value = Eval(op->value); - return TupleValueNode::make({}); + return ADT::Tuple(std::vector()); } else { LOG(FATAL) << "type error, type system should have caught this"; return ObjectRef(); @@ -687,12 +664,12 @@ class Interpreter : } ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValueNode::make(Eval(op->value)); + return RefValueObj::make(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { ObjectRef r = Eval(op->ref); - if (const RefValueNode* rv = r.as()) { + if (const RefValueObj* rv = r.as()) { return rv->value; } else { LOG(FATAL) << "type error, type 
system should have caught this"; @@ -712,7 +689,7 @@ class Interpreter : } bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final { - const ConstructorValueNode* cvn = v.as(); + const ConstructorValueObj* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); CHECK_NE(cvn->tag, -1); @@ -729,11 +706,10 @@ class Interpreter : } bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final { - const TupleValueNode* tvn = v.as(); - CHECK(tvn) << "need to be a tuple for match"; - CHECK_EQ(op->patterns.size(), tvn->fields.size()); + auto adt = Downcast(v); + CHECK_EQ(op->patterns.size(), adt.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { - if (!VisitPattern(op->patterns[i], tvn->fields[i])) { + if (!VisitPattern(op->patterns[i], adt[i])) { return false; } } @@ -774,6 +750,14 @@ class Interpreter : // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; const Op& shape_of_op_; + // The free vars captured by the last closure. + std::unordered_map, ObjectHash, ObjectEqual> closure_captured_vars_; + // The index of the Relay function being evaluated. + int func_index_{0}; + // The Relay function to index map. + std::unordered_map func_index_map_; + // The saved functions. 
+ std::vector eval_funcs_; }; @@ -804,8 +788,5 @@ CreateInterpreter( TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") .set_body_typed(CreateInterpreter); -TVM_REGISTER_NODE_TYPE(ClosureNode); -TVM_REGISTER_NODE_TYPE(TupleValueNode); - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 352a1d77c8772..a28709d4b2a9a 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -187,10 +188,11 @@ class ConstantFolder : public ExprMutator { << "invalid dimension after constant eval"; } return ConstantNode::make(nd_array); - } else if (const auto* val = value.as()) { + } else if (const auto* val = value.as()) { + runtime::ADT adt = GetRef(val); Array fields; - for (ObjectRef field : val->fields) { - fields.push_back(ObjectToExpr(field)); + for (size_t i = 0; i < adt.size(); ++i) { + fields.push_back(ObjectToExpr(adt[i])); } return TupleNode::make(fields); } else { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index a2e8d06a689b6..14feea86f3f2f 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -935,11 +935,12 @@ class PartialEvaluator : public ExprFunctor if (v->IsInstance()) { auto nd_array = Downcast(v); return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array))); - } else if (const TupleValueNode* op = v.as()) { + } else if (const runtime::ADTObj* op = v.as()) { std::vector fields; tvm::Array fields_dyn; - for (const ObjectRef& field : op->fields) { - PStatic ps = Reify(field, ll); + auto adt = GetRef(op); + for (size_t i = 0; i < adt.size(); ++i) { + PStatic ps = Reify(adt[i], ll); fields.push_back(ps); fields_dyn.push_back(ps->dynamic); } diff --git a/src/runtime/vm/object.cc b/src/runtime/container.cc similarity index 70% rename from src/runtime/vm/object.cc rename to 
src/runtime/container.cc index d7760d53e4df7..103b31035ef87 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/container.cc @@ -18,38 +18,27 @@ */ /*! - * \file src/runtime/vm/object.cc - * \brief VM related objects. + * \file src/runtime/container.cc + * \brief Implementation of common POD containers. */ -#include + #include -#include -#include #include +#include +#include #include -#include -#include "../runtime_base.h" namespace tvm { namespace runtime { -namespace vm { - -Closure::Closure(size_t func_index, std::vector free_vars) { - auto ptr = make_object(); - ptr->func_index = func_index; - ptr->free_vars = std::move(free_vars); - data_ = std::move(ptr); -} - -TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") +TVM_REGISTER_GLOBAL("container._GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") +TVM_REGISTER_GLOBAL("container._GetADTSize") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); @@ -57,7 +46,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") }); -TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") +TVM_REGISTER_GLOBAL("container._GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; @@ -66,7 +55,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("_vmobj.Tuple") +TVM_REGISTER_GLOBAL("container._Tuple") .set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { @@ -75,7 +64,7 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("_vmobj.ADT") +TVM_REGISTER_GLOBAL("container._ADT") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); @@ -88,15 +77,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT") TVM_REGISTER_OBJECT_TYPE(ADTObj); 
TVM_REGISTER_OBJECT_TYPE(ClosureObj); -} // namespace vm + } // namespace runtime } // namespace tvm - -using namespace tvm::runtime; - -int TVMGetObjectTag(TVMObjectHandle handle, int* tag) { - API_BEGIN(); - int res = static_cast(static_cast(handle)->type_index()); - *tag = res; - API_END(); -} diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 19e5b1ff9c3c7..8ff0d410184a8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -62,16 +62,11 @@ def convert_to_list(x): def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.ADT): + elif isinstance(o, tvm.container.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) return result - elif isinstance(o, tvm.relay.backend.interpreter.TupleValue): - result = [] - for f in o.fields: - result.append(vmobj_to_list(f)) - return result elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): if o.constructor.name_hint == 'Cons': tl = vmobj_to_list(o.fields[1]) diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index cfb3fd42f8341..f9aa6ce8bfa76 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -19,10 +19,9 @@ import tvm from tvm.contrib import graph_runtime -from tvm import relay +from tvm import relay, container from tvm.relay import testing from tvm.relay import vm -from tvm.relay import vmobj as _obj def benchmark_execution(mod, @@ -69,7 +68,7 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number, repeat=repeat) # Measure in millisecond. 
- prof_res = np.array(ftimer("main", _obj.Tensor(data)).results) * 1000 + prof_res = np.array(ftimer("main", data).results) * 1000 print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 8e304bd856e94..00c07b928b1ab 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -117,7 +117,7 @@ def tree_to_dict(t): def vmobj_to_list(o, dtype="float32"): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.ADT): + elif isinstance(o, tvm.container.ADT): if len(o) == 0: tensor_nil = p.get_var("tensor_nil", dtype=dtype) if tensor_nil.tag == o.tag: diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 85bba4402ea29..11a9e05b3e7f1 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -18,8 +18,7 @@ import tvm import tvm.testing from tvm import nd -from tvm import relay -from tvm.relay.backend.interpreter import TupleValue +from tvm import relay, container from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor @@ -39,7 +38,8 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): def test_tuple_value(): - tv = TupleValue(relay.const(1), relay.const(2), relay.const(3)) + tv = container.tuple_object([relay.const(1), relay.const(2), + relay.const(3)]) np.testing.assert_allclose(tv[0].data.asnumpy(), 1) np.testing.assert_allclose(tv[1].data.asnumpy(), 2) np.testing.assert_allclose(tv[2].data.asnumpy(), 3) @@ -178,7 +178,7 @@ def test_function_taking_adt_ref_tuple(): ], prelude.cons) ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32'))) - tuple_value = TupleValue(*[ + tuple_value = container.tuple_object([ 
nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10) ]) @@ -202,8 +202,8 @@ def test_function_taking_adt_ref_tuple(): res_tuple = id_func(tuple_value) for i in range(10): - tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), - tuple_value.fields[i].asnumpy()) + tvm.testing.assert_allclose(res_tuple[i].asnumpy(), + tuple_value[i].asnumpy()) def test_tuple_passing(): x = relay.var('x', type_annotation=relay.ty.TupleType([ @@ -224,7 +224,8 @@ def test_tuple_passing(): out = f((10, 8)) tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) # Second use a tuple value. - value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12))) + value_tuple = container.tuple_object([nd.array(np.array(11)), + nd.array(np.array(12))]) out = f(value_tuple) tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index f87f90a85a0b9..76aa697a2aaba 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -19,7 +19,8 @@ from tvm import relay from tvm.relay.testing import to_python, run_as_python from tvm.relay.prelude import Prelude -from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue +from tvm.container import ADT +from tvm.relay.backend.interpreter import RefValue, ConstructorValue # helper: uses a dummy let binding to sequence a list # of expressions: expr1; expr2; expr3, etc. 
@@ -45,10 +46,10 @@ def assert_tensor_value(candidate, val): assert np.array_equal(candidate.asnumpy(), np.array(val)) -# assert that the candidate is a TupleValue with the indicate number of fields -def assert_tuple_value(candidate, fields): - assert isinstance(candidate, TupleValue) - assert len(candidate.fields) == fields +# assert that the candidate is an ADT with the indicated number of fields +def assert_adt_len(candidate, fields): + assert isinstance(candidate, ADT) + assert len(candidate) == fields # assert that the candidate is a ConstructorValue with the approrpaite constructor @@ -62,7 +63,7 @@ def assert_constructor_value(candidate, constructor, fields): def test_create_empty_tuple(): empty = relay.Tuple([]) tup_val = run_as_python(empty) - assert_tuple_value(tup_val, 0) + assert_adt_len(tup_val, 0) def test_create_scalar(): @@ -87,12 +88,12 @@ def test_create_nested_tuple(): ]) ]) tup_val = run_as_python(relay_tup) - assert_tuple_value(tup_val, 3) + assert_adt_len(tup_val, 3) for i in range(2): - assert_tensor_value(tup_val.fields[i], i + 1) - assert_tuple_value(tup_val.fields[2], 2) + assert_tensor_value(tup_val[i], i + 1) + assert_adt_len(tup_val[2], 2) for i in range(2): - assert_tensor_value(tup_val.fields[2].fields[i], i + 3) + assert_tensor_value(tup_val[2][i], i + 3) def test_tuple_get_item(): @@ -118,23 +119,23 @@ def test_create_let(): v = relay.Var('v') let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v])) tup_val = run_as_python(let) - assert_tuple_value(tup_val, 2) - assert_tuple_value(tup_val.fields[0], 0) - assert_tuple_value(tup_val.fields[1], 0) + assert_adt_len(tup_val, 2) + assert_adt_len(tup_val[0], 0) + assert_adt_len(tup_val[1], 0) def test_create_ref(): relay_ref = relay.RefCreate(relay.Tuple([])) ref_val = run_as_python(relay_ref) assert isinstance(ref_val, RefValue) - assert_tuple_value(ref_val.value, 0) + assert_adt_len(ref_val.value, 0) def test_ref_read(): v = relay.Var('v') assign = relay.Let(v, 
relay.RefCreate(relay.Tuple([])), relay.RefRead(v)) read_val = run_as_python(assign) - assert_tuple_value(read_val, 0) + assert_adt_len(read_val, 0) def test_ref_write(): @@ -143,7 +144,7 @@ def test_ref_write(): initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])), relay.RefWrite(v, relay.Tuple([relay.const(2)]))) write_val = run_as_python(initial_write) - assert_tuple_value(write_val, 0) + assert_adt_len(write_val, 0) # now ensure that the value, once written, can be read back # (we read the value before and after mutation) @@ -155,11 +156,11 @@ def test_ref_write(): seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])), relay.Tuple([relay.RefRead(w), relay.RefRead(v)])))) read_val = run_as_python(read_after_write) - assert_tuple_value(read_val, 2) - assert_tuple_value(read_val.fields[0], 1) - assert_tuple_value(read_val.fields[1], 1) - assert_tensor_value(read_val.fields[0].fields[0], 1) - assert_tensor_value(read_val.fields[1].fields[0], 2) + assert_adt_len(read_val, 2) + assert_adt_len(read_val[0], 1) + assert_adt_len(read_val[1], 1) + assert_tensor_value(read_val[0][0], 1) + assert_tensor_value(read_val[1][0], 2) def test_if(): @@ -191,7 +192,7 @@ def test_local_function(): call2 = relay.Let(f, ident, f(relay.const(2))) call_val1 = run_as_python(call1) - assert_tuple_value(call_val1, 0) + assert_adt_len(call_val1, 0) call_val2 = run_as_python(call2) assert_tensor_value(call_val2, 2) @@ -211,9 +212,9 @@ def test_global_function(): assert_tensor_value(call_val1, 1) call_val2 = run_as_python(call2, mod) - assert_tuple_value(call_val2, 2) - assert_tensor_value(call_val2.fields[0], 2) - assert_tensor_value(call_val2.fields[1], 2) + assert_adt_len(call_val2, 2) + assert_tensor_value(call_val2[0], 2) + assert_tensor_value(call_val2[1], 2) def test_constructor(): @@ -230,7 +231,7 @@ def test_constructor(): box_val_tup = run_as_python(init_box_tup, mod) assert_constructor_value(box_val_tup, box_ctor, 1) - assert_tuple_value(box_val_tup.fields[0], 
0) + assert_adt_len(box_val_tup.fields[0], 0) def test_match_wildcard(): @@ -341,7 +342,7 @@ def test_local_recursion(): assert_tensor_value(val.fields[1].fields[0], 2) assert_constructor_value(val.fields[1].fields[1], p.cons, 2) assert_tensor_value(val.fields[1].fields[1].fields[0], 3) - assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) + assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) def test_global_recursion(): @@ -372,7 +373,7 @@ def test_global_recursion(): call2 = copy_def(p.cons(relay.Tuple([]), p.nil())) val2 = run_as_python(call2, mod) assert_constructor_value(val2, p.cons, 2) - assert_tuple_value(val2.fields[0], 0) + assert_adt_len(val2.fields[0], 0) assert_constructor_value(val2.fields[1], p.nil, 0) @@ -437,10 +438,10 @@ def test_arbitrary_let_nesting(): ]) tup_val = run_as_python(expr, mod) - assert_tuple_value(tup_val, 3) - assert_tensor_value(tup_val.fields[0], 2) - assert_tensor_value(tup_val.fields[1], 3) - assert_tensor_value(tup_val.fields[2], 4) + assert_adt_len(tup_val, 3) + assert_tensor_value(tup_val[0], 2) + assert_tensor_value(tup_val[1], 3) + assert_tensor_value(tup_val[2], 4) def test_ref_execution_order(): @@ -475,12 +476,12 @@ def test_ref_execution_order(): ]))) tup_val = run_as_python(expr) - assert_tuple_value(tup_val, 5) - assert_tensor_value(tup_val.fields[0], 1) - assert_tensor_value(tup_val.fields[1], 2) - assert_tensor_value(tup_val.fields[2], 3) - assert_tensor_value(tup_val.fields[3], 4) - assert_tensor_value(tup_val.fields[4], 5) + assert_adt_len(tup_val, 5) + assert_tensor_value(tup_val[0], 1) + assert_tensor_value(tup_val[1], 2) + assert_tensor_value(tup_val[2], 3) + assert_tensor_value(tup_val[3], 4) + assert_tensor_value(tup_val[4], 5) def test_op_add(): @@ -501,6 +502,7 @@ def verify_stack(dshapes, axis): args.append(relay.const(data)) call = relay.stack(relay.Tuple(args), axis) call_val = run_as_python(call) + type(call_val) assert_tensor_value(call_val, ref_res) 
verify_stack([(2,), (2,), (2,)], -1) @@ -517,9 +519,9 @@ def verify_split(shape, indices_or_sections, axis=0): ref_res = np.split(x, indices_or_sections, axis=axis) call = relay.split(relay.const(x), indices_or_sections, axis=axis) call_val = run_as_python(call) - assert_tuple_value(call_val, len(ref_res)) + assert_adt_len(call_val, len(ref_res)) for i in range(len(ref_res)): - assert_tensor_value(call_val.fields[i], ref_res[i]) + assert_tensor_value(call_val[i], ref_res[i]) verify_split((2, 3), 2) verify_split((5, 3), [3]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d53360d766560..d4a7a1a25689f 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vm.ADT): + elif isinstance(o, tvm.container.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py deleted file mode 100644 index 82a2b116d82c8..0000000000000 --- a/tests/python/relay/test_vm_object.py +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np -import tvm -from tvm.relay import vm - -def test_adt(): - arr = tvm.nd.array([1,2,3]) - y = vm.ADT(0, [arr, arr]) - - assert len(y) == 2 - assert isinstance(y, vm.ADT) - y[0:1][-1] == arr - assert y.tag == 0 - assert isinstance(arr, tvm.nd.NDArray) - - -if __name__ == "__main__": - test_adt() diff --git a/tests/python/unittest/test_container.py b/tests/python/unittest/test_container.py new file mode 100644 index 0000000000000..7bdab82d7a659 --- /dev/null +++ b/tests/python/unittest/test_container.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import numpy as np +import tvm +from tvm import nd, relay +from tvm import container as _container + + +def test_adt_constructor(): + arr = nd.array([1, 2, 3]) + fields = [arr, arr] + y = _container.ADT(0, [arr, arr]) + + assert len(y) == 2 + assert isinstance(y, _container.ADT) + y[0:1][-1] == arr + assert y.tag == 0 + assert isinstance(arr, nd.NDArray) + + +def test_tuple_object(): + x = relay.var( + 'x', + type_annotation=relay.ty.TupleType([ + relay.ty.TensorType((), 'int32'), + relay.ty.TensorType((), 'int32') + ])) + + fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) + mod = relay.Module.from_expr(fn) + + exe = relay.create_executor( + kind="vm", mod=mod, ctx=nd.cpu(), target="llvm") + f = exe.evaluate() + value_tuple = _container.tuple_object( + [nd.array(np.array(11)), + nd.array(np.array(12))]) + # pass an ADT object to evaluate + out = f(value_tuple) + tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) + + +if __name__ == "__main__": + test_adt_constructor() + test_tuple_object()