diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h
index 8ef7f6e4ed891..dc35fc26486a8 100644
--- a/include/tvm/relay/interpreter.h
+++ b/include/tvm/relay/interpreter.h
@@ -37,15 +37,11 @@
 #include <tvm/build_module.h>
 #include <tvm/relay/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/runtime/object.h>

 namespace tvm {
 namespace relay {

-/*!
- * \brief A Relay value.
- */
-class Value;
-
 /*!
  * \brief Create an Interpreter function that can
  * evaluate an expression and produce a value.
@@ -65,39 +61,21 @@ class Value;
  * \param target Compiler target flag to compile the functions on the context.
  * \return A function that takes in an expression and returns a value.
  */
-runtime::TypedPackedFunc<Value(Expr)>
+runtime::TypedPackedFunc<ObjectRef(Expr)>
 CreateInterpreter(Module mod, DLContext context, Target target);

-/*! \brief The base container type of Relay values. */
-class ValueNode : public RelayNode {
- public:
-  static constexpr const char* _type_key = "relay.Value";
-  TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
-};
-
-class Value : public ObjectRef {
- public:
-  Value() {}
-  explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
-  const ValueNode* operator->() const {
-    return static_cast<const ValueNode*>(get());
-  }
-
-  using ContainerType = ValueNode;
-};
-
 /*! \brief A Relay closure, i.e. a scope and a function. */
 class Closure;

 /*! \brief The container type of Closures. */
-class ClosureNode : public ValueNode {
+class ClosureNode : public Object {
  public:
   /*! \brief The set of free variables in the closure.
    *
    * These are the captured variables which are required for
    * evaluation when we call the closure.
    */
-  tvm::Map<Var, Value> env;
+  tvm::Map<Var, ObjectRef> env;
   /*! \brief The function which implements the closure.
    *
    * \note May reference the variables contained in the env.
@@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
     v->Visit("func", &func);
   }

-  TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
+  TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

   static constexpr const char* _type_key = "relay.Closure";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
 };

-class Closure : public Value {
+class Closure : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
 };

 /*! \brief A Relay Recursive Closure. A closure that has a name. */
 class RecClosure;

 /*! \brief The container type of RecClosure. */
-class RecClosureNode : public ValueNode {
+class RecClosureNode : public Object {
  public:
   /*! \brief The closure. */
   Closure clos;
@@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode {
   TVM_DLL static RecClosure make(Closure clos, Var bind);

   static constexpr const char* _type_key = "relay.RecClosure";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
 };

-class RecClosure : public Value {
+class RecClosure : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
 };

 /*! \brief A tuple value. */
 class TupleValue;

 /*! \brief Tuple (x, ... y).
  */
-struct TupleValueNode : ValueNode {
-  tvm::Array<Value> fields;
+struct TupleValueNode : Object {
+  tvm::Array<ObjectRef> fields;

   TupleValueNode() {}

   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

-  TVM_DLL static TupleValue make(tvm::Array<Value> value);
+  TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

   static constexpr const char* _type_key = "relay.TupleValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
-};
-
-class TupleValue : public Value {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
-};
-
-/*! \brief A tensor value. */
-class TensorValue;
-
-/*! \brief The tensor value container, wrapping an NDArray. */
-struct TensorValueNode : ValueNode {
-  runtime::NDArray data;
-
-  TensorValueNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
-
-  /*! \brief Build a value from an NDArray. */
-  TVM_DLL static TensorValue make(runtime::NDArray data);
-
-  static constexpr const char* _type_key = "relay.TensorValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
 };

-class TensorValue : public Value {
+class TupleValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
 };

 /*! \brief A reference value. */
 class RefValue;

-struct RefValueNode : ValueNode {
-  mutable Value value;
+struct RefValueNode : Object {
+  mutable ObjectRef value;

   RefValueNode() {}
@@ -208,24 +163,24 @@ struct RefValueNode : ValueNode {
     v->Visit("value", &value);
   }

-  TVM_DLL static RefValue make(Value val);
+  TVM_DLL static RefValue make(ObjectRef val);

   static constexpr const char* _type_key = "relay.RefValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
 };

-class RefValue : public Value {
+class RefValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
 };

 /*! \brief An ADT constructor value. */
 class ConstructorValue;

-struct ConstructorValueNode : ValueNode {
+struct ConstructorValueNode : Object {
   int32_t tag;

-  tvm::Array<Value> fields;
+  tvm::Array<ObjectRef> fields;

   /*! \brief Optional field tracking ADT constructor. */
   Constructor constructor;
@@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
   }

   TVM_DLL static ConstructorValue make(int32_t tag,
-                                       tvm::Array<Value> fields,
+                                       tvm::Array<ObjectRef> fields,
                                        Constructor construtor = {});

   static constexpr const char* _type_key = "relay.ConstructorValue";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
 };

-class ConstructorValue : public Value {
+class ConstructorValue : public ObjectRef {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
 };

 }  // namespace relay
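Note (illustration, not part of the patch): with the Value hierarchy removed, interpreter results are plain ObjectRefs, and tensor values are represented directly by runtime::NDArray. A minimal sketch of what that means on the Python side, assuming the TVM API of this era (relay.create_executor, relay.Module.from_expr):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2,), dtype="float32")
    func = relay.Function([x], x + relay.const(1.0))
    ex = relay.create_executor(kind="debug", mod=relay.Module.from_expr(func))
    out = ex.evaluate()(np.ones((2,), dtype="float32"))
    # The result is a plain tvm.nd.NDArray, not a TensorValue wrapper.
    assert isinstance(out, tvm.nd.NDArray)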
diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index 59e9ae8610380..990ecf5ea7336 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -36,25 +36,6 @@ namespace tvm {
 namespace runtime {
 namespace vm {

-/*! \brief An object containing an NDArray. */
-class TensorObj : public Object {
- public:
-  /*! \brief The NDArray. */
-  NDArray data;
-
-  static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
-  static constexpr const char* _type_key = "vm.Tensor";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
-};
-
-/*! \brief reference to tensor. */
-class Tensor : public ObjectRef {
- public:
-  explicit Tensor(NDArray data);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
-};
-
 /*! \brief An object representing a closure. */
 class ClosureObj : public Object {
  public:
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index 1d53f6a92b07c..128edfca0fe13 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -23,27 +23,13 @@
 from . import _backend
 from .. import _make, analysis, transform
 from .. import module
-from ... import register_func, nd
+from ... import nd
 from ..base import NodeBase, register_relay_node
 from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
 from ..scope_builder import ScopeBuilder
-from . import _vm
-
-class Value(NodeBase):
-    """Base class of all values.
-    """
-    @staticmethod
-    @register_func("relay.from_scalar")
-    def from_scalar(value, dtype=None):
-        """Convert a Python scalar to a Relay scalar."""
-        return TensorValue(const(value, dtype).data)
-
-    def to_vm(self):
-        return _vm._ValueToVM(self)
-

 @register_relay_node
-class TupleValue(Value):
+class TupleValue(NodeBase):
     """A tuple value produced by the interpreter."""
     def __init__(self, *fields):
         self.__init_handle_by_constructor__(
@@ -68,60 +54,32 @@ def __iter__(self):


 @register_relay_node
-class Closure(Value):
+class Closure(NodeBase):
     """A closure produced by the interpreter."""


 @register_relay_node
-class RecClosure(Value):
+class RecClosure(NodeBase):
     """A recursive closure produced by the interpreter."""


 @register_relay_node
-class ConstructorValue(Value):
+class ConstructorValue(NodeBase):
     def __init__(self, tag, fields, constructor):
         self.__init_handle_by_constructor__(
             _make.ConstructorValue, tag, fields, constructor)


 @register_relay_node
-class TensorValue(Value):
-    """A Tensor value produced by the interpreter."""
-
-    def __init__(self, data):
-        """Allocate a new TensorValue and copy the data from `array` into
-        the new array.
-        """
-        if isinstance(data, np.ndarray):
-            data = nd.array(data)
-
-        self.__init_handle_by_constructor__(
-            _make.TensorValue, data)
-
-    def asnumpy(self):
-        """Convert a Relay TensorValue into a numpy.ndarray."""
-        return self.data.asnumpy()
-
-    def __eq__(self, other):
-        return self.data == other.data
-
-    def __repr__(self):
-        return repr(self.data)
-
-    def __str__(self):
-        return str(self.data)
-
-
-@register_relay_node
-class RefValue(Value):
+class RefValue(NodeBase):
     def __init__(self, value):
         self.__init_handle_by_constructor__(
             _make.RefValue, value)


 def _arg_to_ast(mod, arg):
-    if isinstance(arg, TensorValue):
-        return Constant(arg.data.copyto(nd.cpu(0)))
+    if isinstance(arg, nd.NDArray):
+        return Constant(arg.copyto(nd.cpu(0)))
     elif isinstance(arg, TupleValue):
         return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
     elif isinstance(arg, tuple):
@@ -231,7 +189,7 @@ def evaluate(self, expr=None, binds=None):

         Returns
         -------
-        val : Union[function, Value]
+        val : Union[function, NodeBase]
             The evaluation result.
         """
         if binds:
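Note (illustration, not part of the patch): _arg_to_ast now treats a bare tvm.nd.NDArray as a tensor argument, so callers pass NDArrays or numpy arrays directly instead of wrapping them in TensorValue. A sketch, assuming the same era's API:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    y = relay.var("y", shape=(2, 2))
    f = relay.Function([x, y], x * y)
    ex = relay.create_executor(kind="debug", mod=relay.Module.from_expr(f))
    a = np.random.rand(2, 2).astype("float32")          # numpy is converted
    b = tvm.nd.array(np.ones((2, 2), dtype="float32"))  # NDArray is used as-is
    res = ex.evaluate()(a, b)
    np.testing.assert_allclose(res.asnumpy(), a * b.asnumpy())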
""" if binds: diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index a523722def61a..91b2554fa63df 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -31,16 +31,18 @@ from . import vmobj as _obj from .interpreter import Executor -Tensor = _obj.Tensor ADT = _obj.ADT def _convert(arg, cargs): if isinstance(arg, _expr.Constant): - cargs.append(_obj.Tensor(arg.data)) + cargs.append(arg.data) elif isinstance(arg, _obj.Object): cargs.append(arg) - elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)): - cargs.append(_obj.Tensor(arg)) + elif isinstance(arg, np.ndarray): + nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0)) + cargs.append(nd_arr) + elif isinstance(arg, tvm.nd.NDArray): + cargs.append(arg) elif isinstance(arg, (tuple, list)): field_args = [] for field in arg: @@ -48,7 +50,7 @@ def _convert(arg, cargs): cargs.append(_obj.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): dtype = "int32" if isinstance(arg, (int, bool)) else "float32" - value = _obj.Tensor(np.array(arg, dtype=dtype)) + value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0)) cargs.append(value) else: raise TypeError("Unsupported type: %s" % (type(arg))) diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index f3fdb763209da..330257ff94674 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -16,51 +16,12 @@ # under the License. """TVM Runtime Object API.""" from __future__ import absolute_import as _abs -import numpy as _np from tvm._ffi.object import Object, register_object, getitem_helper from tvm import ndarray as _nd from . import _vmobj -@register_object("vm.Tensor") -class Tensor(Object): - """Tensor object. - - Parameters - ---------- - arr : numpy.ndarray or tvm.nd.NDArray - The source array. - - ctx : TVMContext, optional - The device context to create the array - """ - def __init__(self, arr, ctx=None): - if isinstance(arr, _np.ndarray): - ctx = ctx if ctx else _nd.cpu(0) - self.__init_handle_by_constructor__( - _vmobj.Tensor, _nd.array(arr, ctx=ctx)) - elif isinstance(arr, _nd.NDArray): - self.__init_handle_by_constructor__( - _vmobj.Tensor, arr) - else: - raise RuntimeError("Unsupported type for tensor object.") - - @property - def data(self): - return _vmobj.GetTensorData(self) - - def asnumpy(self): - """Convert data to numpy array - - Returns - ------- - np_arr : numpy.ndarray - The corresponding numpy array. - """ - return self.data.asnumpy() - - @register_object("vm.ADT") class ADT(Object): """Algebatic data type(ADT) object. @@ -75,7 +36,8 @@ class ADT(Object): """ def __init__(self, tag, fields): for f in fields: - assert isinstance(f, Object) + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " + "tvm NDArray type, but received : {0}".format(type(f)) self.__init_handle_by_constructor__( _vmobj.ADT, tag, *fields) @@ -105,5 +67,6 @@ def tuple_object(fields): The created object. 
""" for f in fields: - assert isinstance(f, Object) + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " + "NDArray type, but received : {0}".format(type(f)) return _vmobj.Tuple(*fields) diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d7b59922b89df..1edb27ae5eb32 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -32,15 +32,16 @@ # import numpy # import tvm # from tvm import relay -# from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue +# from tvm import nd +# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue PROLOGUE = [ ast.Import([alias('numpy', None)]), ast.Import([alias('tvm', None)]), ast.ImportFrom('tvm', [alias('relay', None)], 0), + ast.ImportFrom('tvm', [alias('nd', None)], 0), ast.ImportFrom('tvm.relay.backend.interpreter', [alias('RefValue', None), alias('TupleValue', None), - alias('TensorValue', None), alias('ConstructorValue', None)], 0) ] @@ -245,7 +246,7 @@ def convert_input(py_input, arg_type): a tensor or tuple (returns list of inputs to the lowered op call)""" # equivalent: input.data if isinstance(arg_type, relay.TensorType): - return [ast.Attribute(py_input, 'data', Load())] + return [py_input] assert isinstance(arg_type, relay.TupleType) # convert each input.fields[i] ret = [] @@ -265,15 +266,13 @@ def convert_output(ret_type): output_var_name = self.generate_var_name('_out') output_var = Name(output_var_name, Load()) shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load()) - # create a new TensorValue of the right shape and dtype + # create a new NDArray of the right shape and dtype assign_output = Assign( [Name(output_var_name, Store())], - self.create_call('TensorValue', [ + self.create_call('nd.array', [ self.create_call('numpy.empty', [shape, Str(ret_type.dtype)]) ])) - # we pass the data field as an argument - extra_arg = ast.Attribute(output_var, 'data', Load()) - return ([assign_output], [extra_arg], output_var) + return ([assign_output], [output_var], output_var) assert isinstance(ret_type, relay.TupleType) assignments = [] extra_args = [] @@ -459,7 +458,7 @@ def visit_if(self, if_block: Expr): true_body, true_defs = self.visit(if_block.true_branch) false_body, false_defs = self.visit(if_block.false_branch) - # need to get the value out of a TensorValue to check the condition + # need to get the value out of a NDArray to check the condition # equvialent to: val.asnumpy() cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], []) ret = ast.IfExp(cond_check, true_body, false_body) @@ -474,7 +473,7 @@ def visit_constant(self, constant: Expr): const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()), [self.parse_numpy_array(value)], [ast.keyword('dtype', Str(constant.checked_type.dtype))]) - return (self.create_call('TensorValue', [const_expr]), []) + return (self.create_call('nd.array', [const_expr]), []) def visit_function(self, func: Expr): diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index c1e4fd59d0425..432ad29b13ce6 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -43,8 +43,8 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { return *pf; } -/* Value Implementation */ -Closure ClosureNode::make(tvm::Map env, Function func) { +/* Object Implementation */ +Closure ClosureNode::make(tvm::Map env, Function func) { ObjectPtr n 
= make_object(); n->env = std::move(env); n->func = std::move(func); @@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) // TODO(@jroesch): this doesn't support mutual letrec -/* Value Implementation */ +/* Object Implementation */ RecClosure RecClosureNode::make(Closure clos, Var bind) { ObjectPtr n = make_object(); n->clos = std::move(clos); @@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "RecClosureNode(" << node->clos << ")"; }); -TupleValue TupleValueNode::make(tvm::Array value) { +TupleValue TupleValueNode::make(tvm::Array value) { ObjectPtr n = make_object(); n->fields = value; return TupleValue(n); @@ -94,24 +94,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "TupleValueNode(" << node->fields << ")"; }); -TensorValue TensorValueNode::make(runtime::NDArray data) { - ObjectPtr n = make_object(); - n->data = std::move(data); - return TensorValue(n); -} - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - auto to_str = GetPackedFunc("relay._tensor_value_repr"); - std::string data_str = to_str(GetRef(node)); - p->stream << "TensorValueNode(" << data_str << ")"; - }); - -TVM_REGISTER_GLOBAL("relay._make.TensorValue") -.set_body_typed(TensorValueNode::make); -RefValue RefValueNode::make(Value value) { +RefValue RefValueNode::make(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; return RefValue(n); @@ -129,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); ConstructorValue ConstructorValueNode::make(int32_t tag, - tvm::Array fields, + tvm::Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; @@ -153,13 +137,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) /*! * \brief A stack frame in the Relay interpreter. * - * Contains a mapping from relay::Var to relay::Value. + * Contains a mapping from relay::Var to relay::ObjectRef. */ struct Frame { /*! \brief The set of local variables and arguments for the frame. */ - tvm::Map locals; + tvm::Map locals; - explicit Frame(tvm::Map locals) : locals(locals) {} + explicit Frame(tvm::Map locals) : locals(locals) {} }; /*! @@ -175,7 +159,7 @@ struct Stack { Frame& current_frame() { return frames.back(); } - Value Lookup(const Var& local) { + ObjectRef Lookup(const Var& local) { for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { auto elem = frame->locals.find(local); if (elem != frame->locals.end()) { @@ -185,7 +169,7 @@ struct Stack { LOG(FATAL) << "could not find variable binding for " << local << "address= " << local.operator->(); - return Value(); + return ObjectRef(); } /*! * A wrapper around Frame to add RAII semantics to pushing and popping @@ -206,7 +190,7 @@ class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ class InterpreterStateNode : public Object { public: - using Frame = tvm::Map; + using Frame = tvm::Map; using Stack = tvm::Array; /*! \brief The current expression under evaluation. */ @@ -246,8 +230,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { // // Conversion to ANF is recommended before running the interpretation. 
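Note (illustration, not part of the patch): the converter now allocates the destination buffer as an NDArray and passes it straight to the lowered operator, instead of building a TensorValue and handing over its .data field. Hand-written, the code it generates for a tensor-typed output is roughly equivalent to the following (lowered_op is a hypothetical stand-in for the packed function it actually calls):

    import numpy
    from tvm import nd

    _out = nd.array(numpy.empty((2, 2), "float32"))  # output buffer of the right shape/dtype
    # lowered_op(input_tensor, _out)                 # inputs and output are passed unwrapped
    # _out itself is the NDArray handed back to the caller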
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index c1e4fd59d0425..432ad29b13ce6 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -43,8 +43,8 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) {
   return *pf;
 }

-/* Value Implementation */
-Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
+/* Object Implementation */
+Closure ClosureNode::make(tvm::Map<Var, ObjectRef> env, Function func) {
   ObjectPtr<ClosureNode> n = make_object<ClosureNode>();
   n->env = std::move(env);
   n->func = std::move(func);
@@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)

 // TODO(@jroesch): this doesn't support mutual letrec
-/* Value Implementation */
+/* Object Implementation */
 RecClosure RecClosureNode::make(Closure clos, Var bind) {
   ObjectPtr<RecClosureNode> n = make_object<RecClosureNode>();
   n->clos = std::move(clos);
@@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "RecClosureNode(" << node->clos << ")";
   });

-TupleValue TupleValueNode::make(tvm::Array<Value> value) {
+TupleValue TupleValueNode::make(tvm::Array<ObjectRef> value) {
   ObjectPtr<TupleValueNode> n = make_object<TupleValueNode>();
   n->fields = value;
   return TupleValue(n);
@@ -94,24 +94,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "TupleValueNode(" << node->fields << ")";
   });

-TensorValue TensorValueNode::make(runtime::NDArray data) {
-  ObjectPtr<TensorValueNode> n = make_object<TensorValueNode>();
-  n->data = std::move(data);
-  return TensorValue(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
-    auto* node = static_cast<const TensorValueNode*>(ref.get());
-    auto to_str = GetPackedFunc("relay._tensor_value_repr");
-    std::string data_str = to_str(GetRef<TensorValue>(node));
-    p->stream << "TensorValueNode(" << data_str << ")";
-  });
-
-TVM_REGISTER_GLOBAL("relay._make.TensorValue")
-.set_body_typed(TensorValueNode::make);
-
-RefValue RefValueNode::make(Value value) {
+RefValue RefValueNode::make(ObjectRef value) {
   ObjectPtr<RefValueNode> n = make_object<RefValueNode>();
   n->value = value;
   return RefValue(n);
@@ -129,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });

 ConstructorValue ConstructorValueNode::make(int32_t tag,
-                                            tvm::Array<Value> fields,
+                                            tvm::Array<ObjectRef> fields,
                                             Constructor constructor) {
   ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
   n->tag = tag;
@@ -153,13 +137,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 /*!
  * \brief A stack frame in the Relay interpreter.
  *
- * Contains a mapping from relay::Var to relay::Value.
+ * Contains a mapping from relay::Var to relay::ObjectRef.
  */
 struct Frame {
   /*! \brief The set of local variables and arguments for the frame. */
-  tvm::Map<Var, Value> locals;
+  tvm::Map<Var, ObjectRef> locals;

-  explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
+  explicit Frame(tvm::Map<Var, ObjectRef> locals) : locals(locals) {}
 };

 /*!
@@ -175,7 +159,7 @@ struct Stack {
   Frame& current_frame() { return frames.back(); }

-  Value Lookup(const Var& local) {
+  ObjectRef Lookup(const Var& local) {
     for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
       auto elem = frame->locals.find(local);
       if (elem != frame->locals.end()) {
@@ -185,7 +169,7 @@ struct Stack {
     LOG(FATAL) << "could not find variable binding for " << local
                << "address= " << local.operator->();
-    return Value();
+    return ObjectRef();
   }
   /*!
    * A wrapper around Frame to add RAII semantics to pushing and popping
@@ -206,7 +190,7 @@ class InterpreterState;
 /*! \brief A container capturing the state of the interpreter. */
 class InterpreterStateNode : public Object {
  public:
-  using Frame = tvm::Map<Var, Value>;
+  using Frame = tvm::Map<Var, ObjectRef>;
   using Stack = tvm::Array<Frame>;

   /*! \brief The current expression under evaluation. */
@@ -246,8 +230,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
 //
 // Conversion to ANF is recommended before running the interpretation.
 class Interpreter :
-      public ExprFunctor<Value(const Expr& n)>,
-             PatternFunctor<bool(const Pattern& p, const Value& v)> {
+      public ExprFunctor<ObjectRef(const Expr& n)>,
+             PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
   Interpreter(Module mod, DLContext context, Target target)
       : mod_(mod),
@@ -264,56 +248,56 @@ class Interpreter :
     return f();
   }

-  void extend(const Var& id, Value v) {
+  void extend(const Var& id, ObjectRef v) {
     stack_.current_frame().locals.Set(id, v);
   }

-  Value Lookup(const Var& local) {
+  ObjectRef Lookup(const Var& local) {
     return stack_.Lookup(local);
   }

-  Value Eval(const Expr& expr) {
+  ObjectRef Eval(const Expr& expr) {
     return VisitExpr(expr);
   }

-  Value VisitExpr(const Expr& expr) final {
-    auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
+  ObjectRef VisitExpr(const Expr& expr) final {
+    auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
     return ret;
   }

-  Value VisitExpr_(const VarNode* var_node) final {
+  ObjectRef VisitExpr_(const VarNode* var_node) final {
     return Lookup(GetRef<Var>(var_node));
   }

-  Value VisitExpr_(const GlobalVarNode* op) final {
+  ObjectRef VisitExpr_(const GlobalVarNode* op) final {
     return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
   }

-  Value VisitExpr_(const OpNode* id) override {
+  ObjectRef VisitExpr_(const OpNode* id) override {
     // TODO(@jroesch): Eta-expand and return in this case.
     LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node "
               << "in "
               << "this case, eta expand";
-    return Value();
+    return ObjectRef();
   }

-  Value VisitExpr_(const ConstantNode* op) final {
-    return TensorValueNode::make(op->data.CopyTo(context_));
+  ObjectRef VisitExpr_(const ConstantNode* op) final {
+    return op->data.CopyTo(context_);
   }

-  Value VisitExpr_(const TupleNode* op) final {
-    std::vector<Value> values;
+  ObjectRef VisitExpr_(const TupleNode* op) final {
+    std::vector<ObjectRef> values;
     for (const auto& field : op->fields) {
-      Value field_value = Eval(field);
+      ObjectRef field_value = Eval(field);
       values.push_back(field_value);
     }
     return TupleValueNode::make(values);
   }

-  Value MakeClosure(const Function& func, Var letrec_name = Var()) {
-    tvm::Map<Var, Value> captured_mod;
+  ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
+    tvm::Map<Var, ObjectRef> captured_mod;
     Array<Var> free_vars = FreeVars(func);

     for (const auto& var : free_vars) {
@@ -334,13 +318,13 @@ class Interpreter :
     return std::move(closure);
   }

-  Value VisitExpr_(const FunctionNode* func_node) final {
+  ObjectRef VisitExpr_(const FunctionNode* func_node) final {
     auto func = GetRef<Function>(func_node);
     return MakeClosure(func);
   }

   Array<Shape> ComputeDynamicShape(const Function& func,
-                                   const Array<Value>& args) {
+                                   const Array<ObjectRef>& args) {
     auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
     auto cfunc = engine_->LowerShapeFunc(key);
     size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
@@ -355,11 +339,10 @@ class Interpreter :
     cpu_ctx.device_type = kDLCPU;
     cpu_ctx.device_id = 0;

-    auto fset_input = [&](size_t i, Value val, bool need_shape) {
-      const TensorValueNode* tv = val.as<TensorValueNode>();
-      CHECK(tv != nullptr) << "expect Tensor argument";
+    auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
+      auto nd_array = Downcast<NDArray>(val);
       if (need_shape) {
-        int64_t ndim = tv->data.Shape().size();
+        int64_t ndim = nd_array.Shape().size();
         NDArray shape_arr;
         if (ndim == 0) {
           shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx);
@@ -367,13 +350,13 @@ class Interpreter :
         } else {
           shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
           int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
           for (auto j = 0; j < ndim; ++j) {
-            data[j] = tv->data.Shape()[j];
+            data[j] = nd_array.Shape()[j];
           }
         }
         inputs[i] = shape_arr;
         setter(i, shape_arr);
       } else {
-        auto arr = tv->data.CopyTo(cpu_ctx);
+        auto arr = nd_array.CopyTo(cpu_ctx);
         inputs[i] = arr;
         setter(i, arr);
       }
@@ -384,7 +367,7 @@ class Interpreter :
       auto arg = args[i];
       auto param = func->params[i];
       int state = cfunc->shape_func_param_states[i]->value;
-      if (arg.as<TensorValueNode>()) {
+      if (arg->IsInstance<NDArray::ContainerType>()) {
         if (state & kNeedInputData) {
           fset_input(arg_counter++, arg, false);
         }
@@ -457,8 +440,8 @@ class Interpreter :
     return out_shapes;
   }

-  Value InvokePrimitiveOp(const Function& func,
-                          const Array<Value>& args) {
+  ObjectRef InvokePrimitiveOp(const Function& func,
+                              const Array<ObjectRef>& args) {
     const auto* call_node = func->body.as<CallNode>();

     if (call_node && call_node->op == debug_op_) {
@@ -478,7 +461,7 @@ class Interpreter :
     // Handle tuple input/output by flattening them.
     size_t arg_len = 0;
     for (size_t i = 0; i < args.size(); ++i) {
-      if (args[i].as<TensorValueNode>()) {
+      if (args[i]->IsInstance<NDArray::ContainerType>()) {
         ++arg_len;
       } else {
         const auto* tvalue = args[i].as<TupleValueNode>();
@@ -497,11 +480,10 @@ class Interpreter :
     std::vector<int> codes(arg_len);
     TVMArgsSetter setter(values.data(), codes.data());

-    auto fset_input = [&](size_t i, Value val) {
-      const TensorValueNode* tv = val.as<TensorValueNode>();
-      CHECK(tv != nullptr) << "expect Tensor argument";
-      setter(i, tv->data);
-      DLContext arg_ctx = tv->data->ctx;
+    auto fset_input = [&](size_t i, ObjectRef val) {
+      const auto nd_array = Downcast<NDArray>(val);
+      setter(i, nd_array);
+      DLContext arg_ctx = nd_array->ctx;
       CHECK(arg_ctx.device_type == context_.device_type &&
             arg_ctx.device_id == context_.device_id)
         << "Interpreter expect context to be "
@@ -509,8 +491,8 @@ class Interpreter :
     };

     int arg_counter = 0;
-    for (Value arg : args) {
-      if (arg.as<TensorValueNode>()) {
+    for (ObjectRef arg : args) {
+      if (arg->IsInstance<NDArray::ContainerType>()) {
         fset_input(arg_counter++, arg);
       } else {
         const TupleValueNode* tuple = arg.as<TupleValueNode>();
@@ -536,10 +518,9 @@ class Interpreter :
         shape.push_back(ivalue[0]);
       }
       DLDataType dtype = rtype->dtype;
-      auto out_tensor = TensorValueNode::make(
-          NDArray::Empty(shape, dtype, context_));
-      setter(num_inputs + i, out_tensor->data);
-      return out_tensor;
+      NDArray nd_array = NDArray::Empty(shape, dtype, context_);
+      setter(num_inputs + i, nd_array);
+      return nd_array;
     };

     Array<Shape> out_shapes;
@@ -560,7 +541,7 @@ class Interpreter :
     TVMRetValue rv;
     if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
       CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
-      Array<Value> fields;
+      Array<ObjectRef> fields;
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
         if (is_dyn) {
           auto sh = out_shapes[i];
@@ -573,7 +554,7 @@ class Interpreter :
       packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
       return TupleValueNode::make(fields);
     } else {
-      Value out_tensor;
+      ObjectRef out_tensor;
       if (is_dyn) {
         CHECK_EQ(out_shapes.size(), 1);
         auto sh = out_shapes[0];
@@ -588,14 +569,16 @@ class Interpreter :
   }

   // Invoke the closure
-  Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
+  ObjectRef Invoke(const Closure& closure,
+                   const tvm::Array<ObjectRef>& args,
+                   const Var& bind = Var()) {
     // Get a reference to the function inside the closure.
     if (closure->func->IsPrimitive()) {
       return InvokePrimitiveOp(closure->func, args);
     }
     auto func = closure->func;
     // Allocate a frame with the parameters and free variables.
-    tvm::Map<Var, Value> locals;
+    tvm::Map<Var, ObjectRef> locals;

     CHECK_EQ(func->params.size(), args.size());
@@ -614,11 +597,11 @@ class Interpreter :
       locals.Set(bind, RecClosureNode::make(closure, bind));
     }

-    return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
+    return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
   }

-  Value VisitExpr_(const CallNode* call) final {
-    tvm::Array<Value> args;
+  ObjectRef VisitExpr_(const CallNode* call) final {
+    tvm::Array<ObjectRef> args;
     for (auto arg : call->args) {
       args.push_back(Eval(arg));
     }
@@ -636,7 +619,7 @@ class Interpreter :
       return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
     }
     // Now we just evaluate and expect to find a closure.
-    Value fn_val = Eval(call->op);
+    ObjectRef fn_val = Eval(call->op);
     if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
       auto closure = GetRef<Closure>(closure_node);
       return this->Invoke(closure, args);
@@ -645,11 +628,11 @@ class Interpreter :
     } else {
       LOG(FATAL) << "internal error: type error, expected function value in the call "
                  << "position";
-      return Value();
+      return ObjectRef();
     }
   }

-  Value VisitExpr_(const LetNode* let) final {
+  ObjectRef VisitExpr_(const LetNode* let) final {
     if (auto func = let->value.as<FunctionNode>()) {
       auto clo = MakeClosure(GetRef<Function>(func), let->var);
       this->extend(let->var, clo);
@@ -661,8 +644,8 @@ class Interpreter :
     return Eval(let->body);
   }

-  Value VisitExpr_(const TupleGetItemNode* op) final {
-    Value val = Eval(op->tuple);
+  ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
+    ObjectRef val = Eval(op->tuple);
     auto product_node = val.as<TupleValueNode>();
     CHECK(product_node)
       << "interal error: when evaluating TupleGetItem expected a tuple value";
@@ -671,13 +654,14 @@ class Interpreter :
     return product_node->fields[op->index];
   }

-  Value VisitExpr_(const IfNode* op) final {
-    Value v = Eval(op->cond);
-    if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
+  ObjectRef VisitExpr_(const IfNode* op) final {
+    ObjectRef v = Eval(op->cond);
+    if (v->IsInstance<NDArray::ContainerType>()) {
+      auto nd_array = Downcast<NDArray>(v);
       DLContext cpu_ctx;
       cpu_ctx.device_type = kDLCPU;
       cpu_ctx.device_id = 0;
-      NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
+      NDArray cpu_array = nd_array.CopyTo(cpu_ctx);
       CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
       // TODO(@jroesch, @MK): Refactor code into helper from DCE.
       if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
@@ -687,47 +671,47 @@ class Interpreter :
       }
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }

-  Value VisitExpr_(const RefWriteNode* op) final {
-    Value r = Eval(op->ref);
+  ObjectRef VisitExpr_(const RefWriteNode* op) final {
+    ObjectRef r = Eval(op->ref);
     if (const RefValueNode* rv = r.as<RefValueNode>()) {
       rv->value = Eval(op->value);
       return TupleValueNode::make({});
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }

-  Value VisitExpr_(const RefCreateNode* op) final {
+  ObjectRef VisitExpr_(const RefCreateNode* op) final {
     return RefValueNode::make(Eval(op->value));
   }

-  Value VisitExpr_(const RefReadNode* op) final {
-    Value r = Eval(op->ref);
+  ObjectRef VisitExpr_(const RefReadNode* op) final {
+    ObjectRef r = Eval(op->ref);
     if (const RefValueNode* rv = r.as<RefValueNode>()) {
       return rv->value;
     } else {
       LOG(FATAL) << "type error, type system should have caught this";
-      return Value();
+      return ObjectRef();
     }
   }

-  Value VisitExpr_(const MatchNode* op) final {
-    Value v = Eval(op->data);
+  ObjectRef VisitExpr_(const MatchNode* op) final {
+    ObjectRef v = Eval(op->data);
     for (const Clause& c : op->clauses) {
       if (VisitPattern(c->lhs, v)) {
         return VisitExpr(c->rhs);
       }
     }
     LOG(FATAL) << "did not find any match";
-    return Value();
+    return ObjectRef();
   }

-  bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
     const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
     CHECK(cvn) << "need to be a constructor for match";
     CHECK_NE(op->constructor->tag, -1);
@@ -744,7 +728,7 @@ class Interpreter :
     return false;
   }

-  bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
     const TupleValueNode* tvn = v.as<TupleValueNode>();
     CHECK(tvn) << "need to be a tuple for match";
     CHECK_EQ(op->patterns.size(), tvn->fields.size());
@@ -756,11 +740,11 @@ class Interpreter :
     return true;
   }

-  bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final {
     return true;
   }

-  bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
+  bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final {
     extend(op->var, v);
     return true;
   }
@@ -783,7 +767,7 @@ class Interpreter :
   DLContext context_;
   // Target parameter being used by the interpreter.
   Target target_;
-  // Value stack.
+  // Object stack.
   Stack stack_;
   // Backend compile engine.
   CompileEngine engine_;
@@ -793,7 +777,7 @@ class Interpreter :
 };

-TypedPackedFunc<Value(Expr)>
+TypedPackedFunc<ObjectRef(Expr)>
 CreateInterpreter(
     Module mod,
     DLContext context,
     Target target) {
@@ -814,7 +798,7 @@ CreateInterpreter(
     CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
     return intrp->Eval(expr);
   };
-  return TypedPackedFunc<Value(Expr)>(packed);
+  return TypedPackedFunc<ObjectRef(Expr)>(packed);
 }

 TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
@@ -822,7 +806,6 @@ TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
 TVM_REGISTER_NODE_TYPE(ClosureNode);
 TVM_REGISTER_NODE_TYPE(TupleValueNode);
-TVM_REGISTER_NODE_TYPE(TensorValueNode);

 }  // namespace relay
 }  // namespace tvm
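Note (illustration, not part of the patch): the If case above downcasts the evaluated condition to an NDArray, copies it to CPU, and requires a boolean scalar. The observable behavior from Python, assuming the same era's API:

    import numpy as np
    import tvm
    from tvm import relay

    c = relay.var("c", shape=(), dtype="bool")
    f = relay.Function([c], relay.If(c, relay.const(1), relay.const(2)))
    ex = relay.create_executor(kind="debug", mod=relay.Module.from_expr(f))
    res = ex.evaluate()(np.array(True, dtype="bool"))
    assert res.asnumpy() == 1  # the result is again a plain NDArray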
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 6a3c580aa56e2..d8a97ff7bb3dd 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -849,7 +849,7 @@ void VMCompiler::Compile(Module mod,

   // populate constants
   for (auto data : context_.constants) {
-    exec_->constants.push_back(vm::Tensor(data));
+    exec_->constants.push_back(data);
   }

   LibraryCodegen();
diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc
index b830de0b246df..e1af07181ef57 100644
--- a/src/relay/pass/fold_constant.cc
+++ b/src/relay/pass/fold_constant.cc
@@ -27,12 +27,14 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
-#include "./pattern_util.h"
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+#include "pattern_util.h"

 namespace tvm {
 namespace relay {

-using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
+using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;

 class ConstantChecker : private ExprVisitor {
  public:
@@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator {
   const Op& cast_op_;

   // Convert value to expression.
-  Expr ValueToExpr(Value value) {
-    if (const auto* val = value.as<TensorValueNode>()) {
-      for (auto dim : val->data.Shape()) {
+  Expr ObjectToExpr(const ObjectRef& value) {
+    if (value->IsInstance<runtime::NDArray::ContainerType>()) {
+      auto nd_array = Downcast<runtime::NDArray>(value);
+      for (auto dim : nd_array.Shape()) {
         CHECK_GT(dim, 0)
           << "invalid dimension after constant eval";
       }
-      return ConstantNode::make(val->data);
+      return ConstantNode::make(nd_array);
     } else if (const auto* val = value.as<TupleValueNode>()) {
       Array<Expr> fields;
-      for (Value field : val->fields) {
-        fields.push_back(ValueToExpr(field));
+      for (ObjectRef field : val->fields) {
+        fields.push_back(ObjectToExpr(field));
       }
       return TupleNode::make(fields);
     } else {
@@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator {
     mod = seq(mod);
     auto entry_func = mod->Lookup("main");
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-    return ValueToExpr(executor_(expr));
+    return ObjectToExpr(executor_(expr));
   }

   // Evaluate a call to the shape_of operator for tensors with constant
@@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator {
       }
     }

-    Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
+    Constant shape = Downcast<Constant>(ObjectToExpr(value));

     if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
       auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
@@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
   // in case we are already in a build context.
   With<BuildConfig> fresh_build_ctx(BuildConfig::Create());

-  return ConstantFolder(CreateInterpreter(
-      mod, ctx, target), mod).Mutate(expr);
+  return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
 }

 namespace transform {
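Note (illustration, not part of the patch): ObjectToExpr rebuilds an evaluated NDArray into a Constant and a TupleValue into a Tuple, so constant folding now round-trips through NDArray rather than TensorValue. A sketch of the pass in action:

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(2,))
    c = relay.const(np.ones((2,), "float32")) + relay.const(np.ones((2,), "float32"))
    mod = relay.Module.from_expr(relay.Function([x], x + c))
    mod = relay.transform.FoldConstant()(mod)
    # the constant subtree is evaluated by the interpreter and re-emitted
    # as a single Constant carrying the folded NDArray
    print(mod["main"])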
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc
index a6b867124448c..b06680307cfd0 100644
--- a/src/relay/pass/partial_eval.cc
+++ b/src/relay/pass/partial_eval.cc
@@ -403,7 +403,7 @@ Fuel MkFTop() {
 /*!
  * \brief A stack frame in the Relay interpreter.
  *
- * Contains a mapping from relay::Var to relay::Value.
+ * Contains a mapping from relay::Var to relay::Object.
  */
 struct Frame {
   /*! \brief The set of local variables and arguments for the frame. */
@@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) {
   return sov.stateful;
 }

-using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
+using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;

 DLContext CPUContext() {
   DLContext ctx;
@@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor
     }
   }

-  PStatic Reify(const Value& v, LetList* ll) const {
-    if (const TensorValueNode* op = v.as<TensorValueNode>()) {
-      return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data)));
+  PStatic Reify(const ObjectRef& v, LetList* ll) const {
+    if (v->IsInstance<runtime::NDArray::ContainerType>()) {
+      auto nd_array = Downcast<runtime::NDArray>(v);
+      return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
     } else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
       std::vector<PStatic> fields;
       tvm::Array<Expr> fields_dyn;
-      for (const Value& field : op->fields) {
+      for (const ObjectRef& field : op->fields) {
         PStatic ps = Reify(field, ll);
         fields.push_back(ps);
         fields_dyn.push_back(ps->dynamic);
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 3714425a3323e..bd650665e1960 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -150,10 +150,8 @@ std::string Executable::Stats() const {
   // Get the number of constants and the shape of each of them.
   oss << "  Constant shapes (# " << constants.size() << "): [";
   for (const auto& it : constants) {
-    const auto* cell = it.as<TensorObj>();
-    CHECK(cell);
-    runtime::NDArray data = cell->data;
-    const auto& shape = data.Shape();
+    const auto constant = Downcast<NDArray>(it);
+    const auto& shape = constant.Shape();

     // Scalar
     if (shape.empty()) {
@@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) {
 void Executable::SaveConstantSection(dmlc::Stream* strm) {
   std::vector<DLTensor*> arrays;
   for (const auto& obj : this->constants) {
-    const auto* cell = obj.as<TensorObj>();
-    CHECK(cell != nullptr);
-    runtime::NDArray data = cell->data;
-    arrays.push_back(const_cast<DLTensor*>(data.operator->()));
+    const auto cell = Downcast<runtime::NDArray>(obj);
+    arrays.push_back(const_cast<DLTensor*>(cell.operator->()));
   }
   strm->Write(static_cast<uint64_t>(this->constants.size()));
   for (const auto& it : arrays) {
@@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) {
   for (size_t i = 0; i < size; i++) {
     runtime::NDArray constant;
     STREAM_CHECK(constant.Load(strm), "constant");
-    runtime::ObjectRef obj = runtime::vm::Tensor(constant);
-    this->constants.push_back(obj);
+    this->constants.push_back(constant);
   }
 }
diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc
index 988ba5d47e7dd..d7760d53e4df7 100644
--- a/src/runtime/vm/object.cc
+++ b/src/runtime/vm/object.cc
@@ -34,12 +34,6 @@ namespace tvm {
 namespace runtime {
 namespace vm {

-Tensor::Tensor(NDArray data) {
-  auto ptr = make_object<TensorObj>();
-  ptr->data = std::move(data);
-  data_ = std::move(ptr);
-}
-
 Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
   auto ptr = make_object<ClosureObj>();
   ptr->func_index = func_index;
@@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
 }

-TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-  ObjectRef obj = args[0];
-  const auto* cell = obj.as<TensorObj>();
-  CHECK(cell != nullptr);
-  *rv = cell->data;
-});
-
 TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
@@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
   *rv = adt[idx];
 });

-TVM_REGISTER_GLOBAL("_vmobj.Tensor")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-*rv = Tensor(args[0].operator NDArray());
-});
-
 TVM_REGISTER_GLOBAL("_vmobj.Tuple")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   std::vector<ObjectRef> fields;
@@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
   *rv = ADT(tag, fields);
 });

-TVM_REGISTER_OBJECT_TYPE(TensorObj);
 TVM_REGISTER_OBJECT_TYPE(ADTObj);
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
 }  // namespace vm
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 10b27d1a0e46e..49aba7f65eb11 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
   return os;
 }

-ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
-  if (const TensorObj* obj = src.as<TensorObj>()) {
-    auto tensor = obj->data;
-    if (tensor->ctx.device_type != ctx.device_type) {
-      auto copy = tensor.CopyTo(ctx);
-      return Tensor(copy);
-    } else {
-      return src;
+inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
+  if (src->IsInstance<NDArray::ContainerType>()) {
+    auto nd_array = Downcast<NDArray>(src);
+    if (nd_array->ctx.device_type != ctx.device_type) {
+      return nd_array.CopyTo(ctx);
     }
-  } else {
-    return src;
   }
+  return src;
 }

 PackedFunc VirtualMachine::GetFunction(const std::string& name,
@@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
     if (const auto* dt_cell = args[i].as<ADTObj>()) {
       for (size_t fi = 0; fi < dt_cell->size; ++fi) {
         auto obj = (*dt_cell)[fi];
-        const auto* tensor = obj.as<TensorObj>();
-        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                                 << obj->GetTypeKey();
-        setter(idx++, tensor->data);
+        auto nd_array = Downcast<NDArray>(obj);
+        setter(idx++, nd_array);
       }
     } else {
-      const auto* tensor = args[i].as<TensorObj>();
-      CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                               << args[i]->GetTypeKey();
-      setter(idx++, tensor->data);
+      auto nd_array = Downcast<NDArray>(args[i]);
+      setter(idx++, nd_array);
     }
   }

@@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
 inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
   int32_t result;
   const auto& obj = ReadRegister(r);
-  const auto* tensor = obj.as<TensorObj>();
-  CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                           << obj->GetTypeKey();
-  NDArray array = tensor->data.CopyTo({kDLCPU, 0});
+  auto nd_array = Downcast<NDArray>(obj);
+  NDArray array = nd_array.CopyTo({kDLCPU, 0});

   if (array->dtype.bits <= 8) {
     result = reinterpret_cast<int8_t*>(array->data)[0];
@@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() {
       case Opcode::LoadConsti: {
         auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
         reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
-        WriteRegister(instr.dst, Tensor(tensor));
+        WriteRegister(instr.dst, tensor);
         pc_++;
         goto main_loop;
       }
@@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() {
         auto tag = adt.tag();
         auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
         reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
-        WriteRegister(instr.dst, Tensor(tag_tensor));
+        WriteRegister(instr.dst, tag_tensor);
         pc_++;
         goto main_loop;
       }
@@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() {

         auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
         auto storage = Downcast<Storage>(storage_obj);
-        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
+        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);

-        auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc_++;
         goto main_loop;
@@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() {
         cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
-        const auto* tensor = shape_tensor_obj.as<TensorObj>();
-        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
-                                 << shape_tensor_obj->GetTypeKey();
-        NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
+        const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
+        NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
         const DLTensor* dl_tensor = shape_tensor.operator->();
         CHECK_EQ(dl_tensor->dtype.code, 0u);
         CHECK_LE(dl_tensor->dtype.bits, 64);
@@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() {

         auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
         auto storage = Downcast<Storage>(storage_obj);
-        auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
+        auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);

-        auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
         pc_++;
         goto main_loop;
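Note (illustration, not part of the patch): VM registers now hold NDArrays directly — LoadConsti writes a 64-bit scalar NDArray, GetTag a 32-bit one, and AllocTensor registers the storage-backed NDArray itself. End to end, a VM-executed function therefore returns an NDArray, e.g. (same era's API assumed):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3,))
    mod = relay.Module.from_expr(relay.Function([x], x + x))
    ex = relay.create_executor(kind="vm", mod=mod)
    out = ex.evaluate()(np.ones((3,), dtype="float32"))
    assert isinstance(out, tvm.nd.NDArray)  # no vm.Tensor wrapper anymore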
diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py
index 612347db1fbd1..e39c41e7078fc 100644
--- a/tests/python/frontend/tensorflow/test_control_flow.py
+++ b/tests/python/frontend/tensorflow/test_control_flow.py
@@ -18,6 +18,7 @@
 import pytest
 import tensorflow as tf
 import numpy as np
+from tvm import nd
 from tvm import relay
 from tvm.relay.frontend.tensorflow import from_tensorflow

@@ -26,7 +27,7 @@ def check_equal(graph, tf_out):
     mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
     ex = relay.create_executor('vm', mod=mod)
     relay_out = ex.evaluate()(**params)
-    if isinstance(relay_out, relay.vmobj.Tensor):
+    if isinstance(relay_out, nd.NDArray):
         np.testing.assert_allclose(tf_out, relay_out.asnumpy())
     else:
         if not isinstance(tf_out, list):
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 97557d3c2a041..b3940817be890 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -60,7 +60,7 @@ def convert_to_list(x):
     }

 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
     elif isinstance(o, tvm.relay.backend.vmobj.ADT):
         result = []
@@ -87,8 +87,6 @@ def vmobj_to_list(o):
         else:
             raise RuntimeError("Unknown object type: %s" %
                                o.constructor.name_hint)
-    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
-        return [o.data.asnumpy()]
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py
index c0185e438a8d1..8e304bd856e94 100644
--- a/tests/python/relay/test_adt.py
+++ b/tests/python/relay/test_adt.py
@@ -115,10 +115,8 @@ def tree_to_dict(t):


 def vmobj_to_list(o, dtype="float32"):
-    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
-    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
-        return [o.asnumpy()]
     elif isinstance(o, tvm.relay.backend.vmobj.ADT):
         if len(o) == 0:
             tensor_nil = p.get_var("tensor_nil", dtype=dtype)
diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py
index c1a19c4d9bb19..85bba4402ea29 100644
--- a/tests/python/relay/test_backend_interpreter.py
+++ b/tests/python/relay/test_backend_interpreter.py
@@ -17,8 +17,9 @@
 import numpy as np
 import tvm
 import tvm.testing
+from tvm import nd
 from tvm import relay
-from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
+from tvm.relay.backend.interpreter import TupleValue
 from tvm.relay.backend.interpreter import RefValue, ConstructorValue
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay import testing, create_executor
@@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
         result.asnumpy(), expected_result, rtol=rtol)


-def test_from_scalar():
-    np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
-    np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
-    np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
-
-
 def test_tuple_value():
-    tv = TupleValue(Value.from_scalar(
-        1), Value.from_scalar(2), Value.from_scalar(3))
-    np.testing.assert_allclose(tv[0].asnumpy(), 1)
-    np.testing.assert_allclose(tv[1].asnumpy(), 2)
-    np.testing.assert_allclose(tv[2].asnumpy(), 3)
+    tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
+    np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
+    np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
+    np.testing.assert_allclose(tv[2].data.asnumpy(), 3)


 def test_tuple_getitem():
@@ -158,12 +152,6 @@ def test_binds():
     tvm.testing.assert_allclose(xx + xx, res)


-def test_tensor_value():
-    x = relay.var("x", shape=(1, 10))
-    xx = np.ones((1, 10)).astype("float32")
-    check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
-
-
 def test_kwargs_params():
     x = relay.var("x", shape=(1, 10))
     y = relay.var("y", shape=(1, 10))
@@ -174,7 +162,7 @@ def test_kwargs_params():
     z_data = np.random.rand(1, 10).astype('float32')
     params = { 'y': y_data, 'z': z_data }
     intrp = create_executor("debug")
-    res = intrp.evaluate(f)(x_data, **params).data
+    res = intrp.evaluate(f)(x_data, **params)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)


@@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple():
     nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
     cons_value = ConstructorValue(prelude.cons.tag, [
-        TensorValue(np.random.rand(1, 10).astype('float32')),
+        nd.array(np.random.rand(1, 10).astype('float32')),
         nil_value
     ], prelude.cons)

-    ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
+    ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
     tuple_value = TupleValue(*[
-        TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
+        nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
     ])

     id_func = intrp.evaluate(prelude.id)
@@ -236,9 +224,7 @@ def test_tuple_passing():
     out = f((10, 8))
     tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
     # Second use a tuple value.
-    value_tuple = TupleValue(
-        TensorValue(np.array(11)),
-        TensorValue(np.array(12)))
+    value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
     out = f(value_tuple)
     tvm.testing.assert_allclose(out.asnumpy(), np.array(11))

@@ -252,7 +238,6 @@ def test_tuple_passing():
     test_binds()
     test_kwargs_params()
     test_ref()
-    test_tensor_value()
     test_tuple_value()
     test_tuple_getitem()
     test_function_taking_adt_ref_tuple()
diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py
index 2a07e9514f0cd..f87f90a85a0b9 100644
--- a/tests/python/relay/test_py_converter.py
+++ b/tests/python/relay/test_py_converter.py
@@ -19,7 +19,7 @@
 from tvm import relay
 from tvm.relay.testing import to_python, run_as_python
 from tvm.relay.prelude import Prelude
-from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue
+from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue

 # helper: uses a dummy let binding to sequence a list
 # of expressions: expr1; expr2; expr3, etc.
@@ -39,9 +39,9 @@ def init_box_adt(mod):
     return (box, box_ctor)


-# assert that the candidate is a TensorValue with value val
+# assert that the candidate is an NDArray with value val
 def assert_tensor_value(candidate, val):
-    assert isinstance(candidate, TensorValue)
+    assert isinstance(candidate, tvm.nd.NDArray)
     assert np.array_equal(candidate.asnumpy(), np.array(val))


@@ -544,7 +545,7 @@ def reference(x, gamma, beta, moving_mean, moving_var):

     # there will be a change in accuracy so we need to check
     # approximate equality
-    assert isinstance(call_val, TensorValue)
+    assert isinstance(call_val, tvm.nd.NDArray)
     tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)

     verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index a4c5b7d2a3c3b..6b221542d874a 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
         return vm.invoke("main", *args)

 def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.vm.Tensor):
+    if isinstance(o, tvm.nd.NDArray):
         return [o.asnumpy().tolist()]
     elif isinstance(o, tvm.relay.backend.vm.ADT):
         result = []
diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py
index 12d263d1125b2..82a2b116d82c8 100644
--- a/tests/python/relay/test_vm_object.py
+++ b/tests/python/relay/test_vm_object.py
@@ -19,28 +19,16 @@
 import tvm
 from tvm.relay import vm

-def test_tensor():
-    arr = tvm.nd.array([1,2,3])
-    x = vm.Tensor(arr)
-    assert isinstance(x, vm.Tensor)
-    assert x.asnumpy()[0] == 1
-    assert x.asnumpy()[-1] == 3
-    assert isinstance(x.data, tvm.nd.NDArray)
-
-
 def test_adt():
     arr = tvm.nd.array([1,2,3])
-    x = vm.Tensor(arr)
-    y = vm.ADT(0, [x, x])
+    y = vm.ADT(0, [arr, arr])
     assert len(y) == 2
     assert isinstance(y, vm.ADT)
-    y[0:1][-1].data == x.data
+    y[0:1][-1] == arr
     assert y.tag == 0
-    assert isinstance(x.data, tvm.nd.NDArray)
-
+    assert isinstance(arr, tvm.nd.NDArray)

 if __name__ == "__main__":
-    test_tensor()
     test_adt()
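Note (illustration, not part of the patch): the updated tests all converge on the same idiom — isinstance checks against tvm.nd.NDArray instead of vm.Tensor/TensorValue, with ADTs unwrapped recursively. A simplified, hypothetical helper in the spirit of the tests' vmobj_to_list:

    import tvm
    from tvm import relay

    def to_py(o):
        """Recursively unwrap a VM result: NDArray -> numpy, ADT -> list."""
        if isinstance(o, tvm.nd.NDArray):
            return o.asnumpy()
        if isinstance(o, relay.backend.vmobj.ADT):
            return [to_py(o[i]) for i in range(len(o))]
        raise TypeError("unexpected result type: %s" % type(o))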