From 4e19231b221912333b5a737f5cda17a6a87b04e6 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 7 Jan 2020 05:00:09 +0000 Subject: [PATCH 1/3] replace TensorObj and TensorValue with NDArray --- include/tvm/relay/interpreter.h | 105 +++------- include/tvm/runtime/vm.h | 19 -- python/tvm/relay/backend/interpreter.py | 60 +----- python/tvm/relay/backend/vm.py | 12 +- python/tvm/relay/backend/vmobj.py | 45 +---- python/tvm/relay/testing/py_converter.py | 19 +- src/relay/backend/interpreter.cc | 191 ++++++++---------- src/relay/backend/vm/compiler.cc | 2 +- src/relay/pass/fold_constant.cc | 26 +-- src/relay/pass/partial_eval.cc | 13 +- src/runtime/vm/executable.cc | 15 +- src/runtime/vm/object.cc | 20 -- src/runtime/vm/vm.cc | 50 ++--- .../frontend/tensorflow/test_control_flow.py | 3 +- .../frontend/tensorflow/test_forward.py | 4 +- tests/python/relay/test_adt.py | 4 +- .../python/relay/test_backend_interpreter.py | 37 +--- tests/python/relay/test_py_converter.py | 9 +- tests/python/relay/test_vm.py | 2 +- tests/python/relay/test_vm_object.py | 18 +- 20 files changed, 215 insertions(+), 439 deletions(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 8ef7f6e4ed89..dc35fc26486a 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -37,15 +37,11 @@ #include #include #include +#include namespace tvm { namespace relay { -/*! - * \brief A Relay value. - */ -class Value; - /*! *\brief Create a Interpreter function that can * evaluate an expression and produce a value. @@ -65,39 +61,21 @@ class Value; * \param target Compiler target flag to compile the functions on the context. * \return A function that takes in an expression and returns a value. */ -runtime::TypedPackedFunc +runtime::TypedPackedFunc CreateInterpreter(Module mod, DLContext context, Target target); -/*! \brief The base container type of Relay values. */ -class ValueNode : public RelayNode { - public: - static constexpr const char* _type_key = "relay.Value"; - TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode); -}; - -class Value : public ObjectRef { - public: - Value() {} - explicit Value(ObjectPtr n) : ObjectRef(n) {} - const ValueNode* operator->() const { - return static_cast(get()); - } - - using ContainerType = ValueNode; -}; - /*! \brief A Relay closure, i.e a scope and a function. */ class Closure; /*! \brief The container type of Closures. */ -class ClosureNode : public ValueNode { +class ClosureNode : public Object { public: /*! \brief The set of free variables in the closure. * * These are the captured variables which are required for * evaluation when we call the closure. */ - tvm::Map env; + tvm::Map env; /*! \brief The function which implements the closure. * * \note May reference the variables contained in the env. @@ -111,22 +89,22 @@ class ClosureNode : public ValueNode { v->Visit("func", &func); } - TVM_DLL static Closure make(tvm::Map env, Function func); + TVM_DLL static Closure make(tvm::Map env, Function func); static constexpr const char* _type_key = "relay.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object); }; -class Closure : public Value { +class Closure : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode); + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode); }; /*! \brief A Relay Recursive Closure. A closure that has a name. */ class RecClosure; /*! \brief The container type of RecClosure. */ -class RecClosureNode : public ValueNode { +class RecClosureNode : public Object { public: /*! \brief The closure. */ Closure clos; @@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode { TVM_DLL static RecClosure make(Closure clos, Var bind); static constexpr const char* _type_key = "relay.RecClosure"; - TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object); }; -class RecClosure : public Value { +class RecClosure : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode); + TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode); }; /*! \brief A tuple value. */ class TupleValue; /*! \brief Tuple (x, ... y). */ -struct TupleValueNode : ValueNode { - tvm::Array fields; +struct TupleValueNode : Object { + tvm::Array fields; TupleValueNode() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } - TVM_DLL static TupleValue make(tvm::Array value); + TVM_DLL static TupleValue make(tvm::Array value); static constexpr const char* _type_key = "relay.TupleValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode); -}; - -class TupleValue : public Value { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode); -}; - -/*! \brief A tensor value. */ -class TensorValue; - -/*! \brief The tensor value container, wrapping an NDArray. */ -struct TensorValueNode : ValueNode { - runtime::NDArray data; - - TensorValueNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); } - - /*! \brief Build a value from an NDArray. */ - TVM_DLL static TensorValue make(runtime::NDArray data); - - static constexpr const char* _type_key = "relay.TensorValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object); }; -class TensorValue : public Value { +class TupleValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode); }; /*! \brief A reference value. */ class RefValue; -struct RefValueNode : ValueNode { - mutable Value value; +struct RefValueNode : Object { + mutable ObjectRef value; RefValueNode() {} @@ -208,24 +163,24 @@ struct RefValueNode : ValueNode { v->Visit("value", &value); } - TVM_DLL static RefValue make(Value val); + TVM_DLL static RefValue make(ObjectRef val); static constexpr const char* _type_key = "relay.RefValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object); }; -class RefValue : public Value { +class RefValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode); }; /*! \brief An ADT constructor value. */ class ConstructorValue; -struct ConstructorValueNode : ValueNode { +struct ConstructorValueNode : Object { int32_t tag; - tvm::Array fields; + tvm::Array fields; /*! \brief Optional field tracking ADT constructor. */ Constructor constructor; @@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode { } TVM_DLL static ConstructorValue make(int32_t tag, - tvm::Array fields, + tvm::Array fields, Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object); }; -class ConstructorValue : public Value { +class ConstructorValue : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode); }; } // namespace relay diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 59e9ae861038..990ecf5ea733 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -36,25 +36,6 @@ namespace tvm { namespace runtime { namespace vm { -/*! \brief An object containing an NDArray. */ -class TensorObj : public Object { - public: - /*! \brief The NDArray. */ - NDArray data; - - static constexpr const uint32_t _type_index = TypeIndex::kVMTensor; - static constexpr const char* _type_key = "vm.Tensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object); -}; - -/*! \brief reference to tensor. */ -class Tensor : public ObjectRef { - public: - explicit Tensor(NDArray data); - - TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj); -}; - /*! \brief An object representing a closure. */ class ClosureObj : public Object { public: diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 1d53f6a92b07..128edfca0fe1 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -23,27 +23,13 @@ from . import _backend from .. import _make, analysis, transform from .. import module -from ... import register_func, nd +from ... import nd from ..base import NodeBase, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder -from . import _vm - -class Value(NodeBase): - """Base class of all values. - """ - @staticmethod - @register_func("relay.from_scalar") - def from_scalar(value, dtype=None): - """Convert a Python scalar to a Relay scalar.""" - return TensorValue(const(value, dtype).data) - - def to_vm(self): - return _vm._ValueToVM(self) - @register_relay_node -class TupleValue(Value): +class TupleValue(NodeBase): """A tuple value produced by the interpreter.""" def __init__(self, *fields): self.__init_handle_by_constructor__( @@ -68,60 +54,32 @@ def __iter__(self): @register_relay_node -class Closure(Value): +class Closure(NodeBase): """A closure produced by the interpreter.""" @register_relay_node -class RecClosure(Value): +class RecClosure(NodeBase): """A recursive closure produced by the interpreter.""" @register_relay_node -class ConstructorValue(Value): +class ConstructorValue(NodeBase): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) @register_relay_node -class TensorValue(Value): - """A Tensor value produced by the interpreter.""" - - def __init__(self, data): - """Allocate a new TensorValue and copy the data from `array` into - the new array. - """ - if isinstance(data, np.ndarray): - data = nd.array(data) - - self.__init_handle_by_constructor__( - _make.TensorValue, data) - - def asnumpy(self): - """Convert a Relay TensorValue into a numpy.ndarray.""" - return self.data.asnumpy() - - def __eq__(self, other): - return self.data == other.data - - def __repr__(self): - return repr(self.data) - - def __str__(self): - return str(self.data) - - -@register_relay_node -class RefValue(Value): +class RefValue(NodeBase): def __init__(self, value): self.__init_handle_by_constructor__( _make.RefValue, value) def _arg_to_ast(mod, arg): - if isinstance(arg, TensorValue): - return Constant(arg.data.copyto(nd.cpu(0))) + if isinstance(arg, nd.NDArray): + return Constant(arg.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) elif isinstance(arg, tuple): @@ -231,7 +189,7 @@ def evaluate(self, expr=None, binds=None): Returns ------- - val : Union[function, Value] + val : Union[function, NodeBase] The evaluation result. """ if binds: diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index bad4ac227d09..aba55ef7d13e 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -31,16 +31,18 @@ from . import vmobj as _obj from .interpreter import Executor -Tensor = _obj.Tensor ADT = _obj.ADT def _convert(arg, cargs): if isinstance(arg, _expr.Constant): - cargs.append(_obj.Tensor(arg.data)) + cargs.append(arg.data) elif isinstance(arg, _obj.Object): cargs.append(arg) - elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)): - cargs.append(_obj.Tensor(arg)) + elif isinstance(arg, np.ndarray): + nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0)) + cargs.append(nd_arr) + elif isinstance(arg, tvm.nd.NDArray): + cargs.append(arg) elif isinstance(arg, (tuple, list)): field_args = [] for field in arg: @@ -48,7 +50,7 @@ def _convert(arg, cargs): cargs.append(_obj.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): dtype = "int32" if isinstance(arg, (int, bool)) else "float32" - value = _obj.Tensor(np.array(arg, dtype=dtype)) + value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0)) cargs.append(value) else: raise TypeError("Unsupported type: %s" % (type(arg))) diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index f3fdb763209d..330257ff9467 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -16,51 +16,12 @@ # under the License. """TVM Runtime Object API.""" from __future__ import absolute_import as _abs -import numpy as _np from tvm._ffi.object import Object, register_object, getitem_helper from tvm import ndarray as _nd from . import _vmobj -@register_object("vm.Tensor") -class Tensor(Object): - """Tensor object. - - Parameters - ---------- - arr : numpy.ndarray or tvm.nd.NDArray - The source array. - - ctx : TVMContext, optional - The device context to create the array - """ - def __init__(self, arr, ctx=None): - if isinstance(arr, _np.ndarray): - ctx = ctx if ctx else _nd.cpu(0) - self.__init_handle_by_constructor__( - _vmobj.Tensor, _nd.array(arr, ctx=ctx)) - elif isinstance(arr, _nd.NDArray): - self.__init_handle_by_constructor__( - _vmobj.Tensor, arr) - else: - raise RuntimeError("Unsupported type for tensor object.") - - @property - def data(self): - return _vmobj.GetTensorData(self) - - def asnumpy(self): - """Convert data to numpy array - - Returns - ------- - np_arr : numpy.ndarray - The corresponding numpy array. - """ - return self.data.asnumpy() - - @register_object("vm.ADT") class ADT(Object): """Algebatic data type(ADT) object. @@ -75,7 +36,8 @@ class ADT(Object): """ def __init__(self, tag, fields): for f in fields: - assert isinstance(f, Object) + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " + "tvm NDArray type, but received : {0}".format(type(f)) self.__init_handle_by_constructor__( _vmobj.ADT, tag, *fields) @@ -105,5 +67,6 @@ def tuple_object(fields): The created object. """ for f in fields: - assert isinstance(f, Object) + assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " + "NDArray type, but received : {0}".format(type(f)) return _vmobj.Tuple(*fields) diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d7b59922b89d..1edb27ae5eb3 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -32,15 +32,16 @@ # import numpy # import tvm # from tvm import relay -# from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue +# from tvm import nd +# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue PROLOGUE = [ ast.Import([alias('numpy', None)]), ast.Import([alias('tvm', None)]), ast.ImportFrom('tvm', [alias('relay', None)], 0), + ast.ImportFrom('tvm', [alias('nd', None)], 0), ast.ImportFrom('tvm.relay.backend.interpreter', [alias('RefValue', None), alias('TupleValue', None), - alias('TensorValue', None), alias('ConstructorValue', None)], 0) ] @@ -245,7 +246,7 @@ def convert_input(py_input, arg_type): a tensor or tuple (returns list of inputs to the lowered op call)""" # equivalent: input.data if isinstance(arg_type, relay.TensorType): - return [ast.Attribute(py_input, 'data', Load())] + return [py_input] assert isinstance(arg_type, relay.TupleType) # convert each input.fields[i] ret = [] @@ -265,15 +266,13 @@ def convert_output(ret_type): output_var_name = self.generate_var_name('_out') output_var = Name(output_var_name, Load()) shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load()) - # create a new TensorValue of the right shape and dtype + # create a new NDArray of the right shape and dtype assign_output = Assign( [Name(output_var_name, Store())], - self.create_call('TensorValue', [ + self.create_call('nd.array', [ self.create_call('numpy.empty', [shape, Str(ret_type.dtype)]) ])) - # we pass the data field as an argument - extra_arg = ast.Attribute(output_var, 'data', Load()) - return ([assign_output], [extra_arg], output_var) + return ([assign_output], [output_var], output_var) assert isinstance(ret_type, relay.TupleType) assignments = [] extra_args = [] @@ -459,7 +458,7 @@ def visit_if(self, if_block: Expr): true_body, true_defs = self.visit(if_block.true_branch) false_body, false_defs = self.visit(if_block.false_branch) - # need to get the value out of a TensorValue to check the condition + # need to get the value out of a NDArray to check the condition # equvialent to: val.asnumpy() cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], []) ret = ast.IfExp(cond_check, true_body, false_body) @@ -474,7 +473,7 @@ def visit_constant(self, constant: Expr): const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()), [self.parse_numpy_array(value)], [ast.keyword('dtype', Str(constant.checked_type.dtype))]) - return (self.create_call('TensorValue', [const_expr]), []) + return (self.create_call('nd.array', [const_expr]), []) def visit_function(self, func: Expr): diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index c1e4fd59d042..432ad29b13ce 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -43,8 +43,8 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { return *pf; } -/* Value Implementation */ -Closure ClosureNode::make(tvm::Map env, Function func) { +/* Object Implementation */ +Closure ClosureNode::make(tvm::Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) // TODO(@jroesch): this doesn't support mutual letrec -/* Value Implementation */ +/* Object Implementation */ RecClosure RecClosureNode::make(Closure clos, Var bind) { ObjectPtr n = make_object(); n->clos = std::move(clos); @@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "RecClosureNode(" << node->clos << ")"; }); -TupleValue TupleValueNode::make(tvm::Array value) { +TupleValue TupleValueNode::make(tvm::Array value) { ObjectPtr n = make_object(); n->fields = value; return TupleValue(n); @@ -94,24 +94,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "TupleValueNode(" << node->fields << ")"; }); -TensorValue TensorValueNode::make(runtime::NDArray data) { - ObjectPtr n = make_object(); - n->data = std::move(data); - return TensorValue(n); -} - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - auto to_str = GetPackedFunc("relay._tensor_value_repr"); - std::string data_str = to_str(GetRef(node)); - p->stream << "TensorValueNode(" << data_str << ")"; - }); - -TVM_REGISTER_GLOBAL("relay._make.TensorValue") -.set_body_typed(TensorValueNode::make); -RefValue RefValueNode::make(Value value) { +RefValue RefValueNode::make(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; return RefValue(n); @@ -129,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); ConstructorValue ConstructorValueNode::make(int32_t tag, - tvm::Array fields, + tvm::Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; @@ -153,13 +137,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) /*! * \brief A stack frame in the Relay interpreter. * - * Contains a mapping from relay::Var to relay::Value. + * Contains a mapping from relay::Var to relay::ObjectRef. */ struct Frame { /*! \brief The set of local variables and arguments for the frame. */ - tvm::Map locals; + tvm::Map locals; - explicit Frame(tvm::Map locals) : locals(locals) {} + explicit Frame(tvm::Map locals) : locals(locals) {} }; /*! @@ -175,7 +159,7 @@ struct Stack { Frame& current_frame() { return frames.back(); } - Value Lookup(const Var& local) { + ObjectRef Lookup(const Var& local) { for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { auto elem = frame->locals.find(local); if (elem != frame->locals.end()) { @@ -185,7 +169,7 @@ struct Stack { LOG(FATAL) << "could not find variable binding for " << local << "address= " << local.operator->(); - return Value(); + return ObjectRef(); } /*! * A wrapper around Frame to add RAII semantics to pushing and popping @@ -206,7 +190,7 @@ class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ class InterpreterStateNode : public Object { public: - using Frame = tvm::Map; + using Frame = tvm::Map; using Stack = tvm::Array; /*! \brief The current expression under evaluation. */ @@ -246,8 +230,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { // // Conversion to ANF is recommended before running the interpretation. class Interpreter : - public ExprFunctor, - PatternFunctor { + public ExprFunctor, + PatternFunctor { public: Interpreter(Module mod, DLContext context, Target target) : mod_(mod), @@ -264,56 +248,56 @@ class Interpreter : return f(); } - void extend(const Var& id, Value v) { + void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); } - Value Lookup(const Var& local) { + ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); } - Value Eval(const Expr& expr) { + ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); } - Value VisitExpr(const Expr& expr) final { - auto ret = ExprFunctor::VisitExpr(expr); + ObjectRef VisitExpr(const Expr& expr) final { + auto ret = ExprFunctor::VisitExpr(expr); return ret; } - Value VisitExpr_(const VarNode* var_node) final { + ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef(var_node)); } - Value VisitExpr_(const GlobalVarNode* op) final { + ObjectRef VisitExpr_(const GlobalVarNode* op) final { return Eval(mod_->Lookup(GetRef(op))); } - Value VisitExpr_(const OpNode* id) override { + ObjectRef VisitExpr_(const OpNode* id) override { // TODO(@jroesch): Eta-expand and return in this case. LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node " << "in " << "this case, eta expand"; - return Value(); + return ObjectRef(); } - Value VisitExpr_(const ConstantNode* op) final { - return TensorValueNode::make(op->data.CopyTo(context_)); + ObjectRef VisitExpr_(const ConstantNode* op) final { + return op->data.CopyTo(context_); } - Value VisitExpr_(const TupleNode* op) final { - std::vector values; + ObjectRef VisitExpr_(const TupleNode* op) final { + std::vector values; for (const auto& field : op->fields) { - Value field_value = Eval(field); + ObjectRef field_value = Eval(field); values.push_back(field_value); } return TupleValueNode::make(values); } - Value MakeClosure(const Function& func, Var letrec_name = Var()) { - tvm::Map captured_mod; + ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { + tvm::Map captured_mod; Array free_vars = FreeVars(func); for (const auto& var : free_vars) { @@ -334,13 +318,13 @@ class Interpreter : return std::move(closure); } - Value VisitExpr_(const FunctionNode* func_node) final { + ObjectRef VisitExpr_(const FunctionNode* func_node) final { auto func = GetRef(func_node); return MakeClosure(func); } Array ComputeDynamicShape(const Function& func, - const Array& args) { + const Array& args) { auto key = CCacheKeyNode::make(func, Target::Create("llvm")); auto cfunc = engine_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); @@ -355,11 +339,10 @@ class Interpreter : cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; - auto fset_input = [&](size_t i, Value val, bool need_shape) { - const TensorValueNode* tv = val.as(); - CHECK(tv != nullptr) << "expect Tensor argument"; + auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { + auto nd_array = Downcast(val); if (need_shape) { - int64_t ndim = tv->data.Shape().size(); + int64_t ndim = nd_array.Shape().size(); NDArray shape_arr; if (ndim == 0) { shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); @@ -367,13 +350,13 @@ class Interpreter : shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); int64_t* data = reinterpret_cast(shape_arr->data); for (auto j = 0; j < ndim; ++j) { - data[j] = tv->data.Shape()[j]; + data[j] = nd_array.Shape()[j]; } } inputs[i] = shape_arr; setter(i, shape_arr); } else { - auto arr = tv->data.CopyTo(cpu_ctx); + auto arr = nd_array.CopyTo(cpu_ctx); inputs[i] = arr; setter(i, arr); } @@ -384,7 +367,7 @@ class Interpreter : auto arg = args[i]; auto param = func->params[i]; int state = cfunc->shape_func_param_states[i]->value; - if (arg.as()) { + if (arg->IsInstance()) { if (state & kNeedInputData) { fset_input(arg_counter++, arg, false); } @@ -457,8 +440,8 @@ class Interpreter : return out_shapes; } - Value InvokePrimitiveOp(const Function& func, - const Array& args) { + ObjectRef InvokePrimitiveOp(const Function& func, + const Array& args) { const auto* call_node = func->body.as(); if (call_node && call_node->op == debug_op_) { @@ -478,7 +461,7 @@ class Interpreter : // Handle tuple input/output by flattening them. size_t arg_len = 0; for (size_t i = 0; i < args.size(); ++i) { - if (args[i].as()) { + if (args[i]->IsInstance()) { ++arg_len; } else { const auto* tvalue = args[i].as(); @@ -497,11 +480,10 @@ class Interpreter : std::vector codes(arg_len); TVMArgsSetter setter(values.data(), codes.data()); - auto fset_input = [&](size_t i, Value val) { - const TensorValueNode* tv = val.as(); - CHECK(tv != nullptr) << "expect Tensor argument"; - setter(i, tv->data); - DLContext arg_ctx = tv->data->ctx; + auto fset_input = [&](size_t i, ObjectRef val) { + const auto nd_array = Downcast(val); + setter(i, nd_array); + DLContext arg_ctx = nd_array->ctx; CHECK(arg_ctx.device_type == context_.device_type && arg_ctx.device_id == context_.device_id) << "Interpreter expect context to be " @@ -509,8 +491,8 @@ class Interpreter : }; int arg_counter = 0; - for (Value arg : args) { - if (arg.as()) { + for (ObjectRef arg : args) { + if (arg->IsInstance()) { fset_input(arg_counter++, arg); } else { const TupleValueNode* tuple = arg.as(); @@ -536,10 +518,9 @@ class Interpreter : shape.push_back(ivalue[0]); } DLDataType dtype = rtype->dtype; - auto out_tensor = TensorValueNode::make( - NDArray::Empty(shape, dtype, context_)); - setter(num_inputs + i, out_tensor->data); - return out_tensor; + NDArray nd_array = NDArray::Empty(shape, dtype, context_); + setter(num_inputs + i, nd_array); + return nd_array; }; Array out_shapes; @@ -560,7 +541,7 @@ class Interpreter : TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { CHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); - Array fields; + Array fields; for (size_t i = 0; i < rtype->fields.size(); ++i) { if (is_dyn) { auto sh = out_shapes[i]; @@ -573,7 +554,7 @@ class Interpreter : packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); return TupleValueNode::make(fields); } else { - Value out_tensor; + ObjectRef out_tensor; if (is_dyn) { CHECK_EQ(out_shapes.size(), 1); auto sh = out_shapes[0]; @@ -588,14 +569,16 @@ class Interpreter : } // Invoke the closure - Value Invoke(const Closure& closure, const tvm::Array& args, const Var& bind = Var()) { + ObjectRef Invoke(const Closure& closure, + const tvm::Array& args, + const Var& bind = Var()) { // Get a reference to the function inside the closure. if (closure->func->IsPrimitive()) { return InvokePrimitiveOp(closure->func, args); } auto func = closure->func; // Allocate a frame with the parameters and free variables. - tvm::Map locals; + tvm::Map locals; CHECK_EQ(func->params.size(), args.size()); @@ -614,11 +597,11 @@ class Interpreter : locals.Set(bind, RecClosureNode::make(closure, bind)); } - return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); + return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); } - Value VisitExpr_(const CallNode* call) final { - tvm::Array args; + ObjectRef VisitExpr_(const CallNode* call) final { + tvm::Array args; for (auto arg : call->args) { args.push_back(Eval(arg)); } @@ -636,7 +619,7 @@ class Interpreter : return ConstructorValueNode::make(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. - Value fn_val = Eval(call->op); + ObjectRef fn_val = Eval(call->op); if (const ClosureNode* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return this->Invoke(closure, args); @@ -645,11 +628,11 @@ class Interpreter : } else { LOG(FATAL) << "internal error: type error, expected function value in the call " << "position"; - return Value(); + return ObjectRef(); } } - Value VisitExpr_(const LetNode* let) final { + ObjectRef VisitExpr_(const LetNode* let) final { if (auto func = let->value.as()) { auto clo = MakeClosure(GetRef(func), let->var); this->extend(let->var, clo); @@ -661,8 +644,8 @@ class Interpreter : return Eval(let->body); } - Value VisitExpr_(const TupleGetItemNode* op) final { - Value val = Eval(op->tuple); + ObjectRef VisitExpr_(const TupleGetItemNode* op) final { + ObjectRef val = Eval(op->tuple); auto product_node = val.as(); CHECK(product_node) << "interal error: when evaluating TupleGetItem expected a tuple value"; @@ -671,13 +654,14 @@ class Interpreter : return product_node->fields[op->index]; } - Value VisitExpr_(const IfNode* op) final { - Value v = Eval(op->cond); - if (const TensorValueNode* bv = v.as()) { + ObjectRef VisitExpr_(const IfNode* op) final { + ObjectRef v = Eval(op->cond); + if (v->IsInstance()) { + auto nd_array = Downcast(v); DLContext cpu_ctx; cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; - NDArray cpu_array = bv->data.CopyTo(cpu_ctx); + NDArray cpu_array = nd_array.CopyTo(cpu_ctx); CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool()); // TODO(@jroesch, @MK): Refactor code into helper from DCE. if (reinterpret_cast(cpu_array->data)[0]) { @@ -687,47 +671,47 @@ class Interpreter : } } else { LOG(FATAL) << "type error, type system should have caught this"; - return Value(); + return ObjectRef(); } } - Value VisitExpr_(const RefWriteNode* op) final { - Value r = Eval(op->ref); + ObjectRef VisitExpr_(const RefWriteNode* op) final { + ObjectRef r = Eval(op->ref); if (const RefValueNode* rv = r.as()) { rv->value = Eval(op->value); return TupleValueNode::make({}); } else { LOG(FATAL) << "type error, type system should have caught this"; - return Value(); + return ObjectRef(); } } - Value VisitExpr_(const RefCreateNode* op) final { + ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValueNode::make(Eval(op->value)); } - Value VisitExpr_(const RefReadNode* op) final { - Value r = Eval(op->ref); + ObjectRef VisitExpr_(const RefReadNode* op) final { + ObjectRef r = Eval(op->ref); if (const RefValueNode* rv = r.as()) { return rv->value; } else { LOG(FATAL) << "type error, type system should have caught this"; - return Value(); + return ObjectRef(); } } - Value VisitExpr_(const MatchNode* op) final { - Value v = Eval(op->data); + ObjectRef VisitExpr_(const MatchNode* op) final { + ObjectRef v = Eval(op->data); for (const Clause& c : op->clauses) { if (VisitPattern(c->lhs, v)) { return VisitExpr(c->rhs); } } LOG(FATAL) << "did not find any match"; - return Value(); + return ObjectRef(); } - bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { + bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final { const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); @@ -744,7 +728,7 @@ class Interpreter : return false; } - bool VisitPattern_(const PatternTupleNode* op, const Value& v) final { + bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final { const TupleValueNode* tvn = v.as(); CHECK(tvn) << "need to be a tuple for match"; CHECK_EQ(op->patterns.size(), tvn->fields.size()); @@ -756,11 +740,11 @@ class Interpreter : return true; } - bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { + bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; } - bool VisitPattern_(const PatternVarNode* op, const Value& v) final { + bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final { extend(op->var, v); return true; } @@ -783,7 +767,7 @@ class Interpreter : DLContext context_; // Target parameter being used by the interpreter. Target target_; - // Value stack. + // Object stack. Stack stack_; // Backend compile engine. CompileEngine engine_; @@ -793,7 +777,7 @@ class Interpreter : }; -TypedPackedFunc +TypedPackedFunc CreateInterpreter( Module mod, DLContext context, @@ -814,7 +798,7 @@ CreateInterpreter( CHECK(f.is_subset_of(FeatureSet::All() - fGraph)); return intrp->Eval(expr); }; - return TypedPackedFunc(packed); + return TypedPackedFunc(packed); } TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") @@ -822,7 +806,6 @@ TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") TVM_REGISTER_NODE_TYPE(ClosureNode); TVM_REGISTER_NODE_TYPE(TupleValueNode); -TVM_REGISTER_NODE_TYPE(TensorValueNode); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 5d262a09a84d..bb47685c7ece 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -854,7 +854,7 @@ void VMCompiler::Lower(Module mod, // populate constants for (auto data : context_.constants) { - exec_->constants.push_back(vm::Tensor(data)); + exec_->constants.push_back(data); } // update global function map diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index d36733df341b..7f00c718e83e 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -27,12 +27,14 @@ #include #include #include -#include "./pattern_util.h" +#include +#include +#include "pattern_util.h" namespace tvm { namespace relay { -using FInterpreter = runtime::TypedPackedFunc; +using FInterpreter = runtime::TypedPackedFunc; class ConstantChecker : private ExprVisitor { public: @@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator { const Op& cast_op_; // Convert value to expression. - Expr ValueToExpr(Value value) { - if (const auto* val = value.as()) { - for (auto dim : val->data.Shape()) { + Expr ObjectToExpr(const ObjectRef& value) { + if (value->IsInstance()) { + auto nd_array = Downcast(value); + for (auto dim : nd_array.Shape()) { CHECK_GT(dim, 0) << "invalid dimension after constant eval"; } - return ConstantNode::make(val->data); + return ConstantNode::make(nd_array); } else if (const auto* val = value.as()) { Array fields; - for (Value field : val->fields) { - fields.push_back(ValueToExpr(field)); + for (ObjectRef field : val->fields) { + fields.push_back(ObjectToExpr(field)); } return TupleNode::make(fields); } else { @@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator { mod = seq(mod); auto entry_func = mod->Lookup("main"); expr = expr.as() == nullptr ? entry_func->body : entry_func; - return ValueToExpr(executor_(expr)); + return ObjectToExpr(executor_(expr)); } // Evaluate a call to the shape_of operator for tensors with constant @@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator { } } - Constant shape = Downcast(ValueToExpr(TensorValueNode::make(value))); + Constant shape = Downcast(ObjectToExpr(value)); if (shape->data.Shape().size() == 0 && GetScalarFromConstant(shape) == 0) { auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx); @@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) { // in case we are already in a build context. With fresh_build_ctx(BuildConfig::Create()); - return ConstantFolder(CreateInterpreter( - mod, ctx, target), mod).Mutate(expr); + return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); } namespace transform { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index a6b867124448..b06680307cfd 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -403,7 +403,7 @@ Fuel MkFTop() { /*! * \brief A stack frame in the Relay interpreter. * - * Contains a mapping from relay::Var to relay::Value. + * Contains a mapping from relay::Var to relay::Object. */ struct Frame { /*! \brief The set of local variables and arguments for the frame. */ @@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) { return sov.stateful; } -using FInterpreter = runtime::TypedPackedFunc; +using FInterpreter = runtime::TypedPackedFunc; DLContext CPUContext() { DLContext ctx; @@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor } } - PStatic Reify(const Value& v, LetList* ll) const { - if (const TensorValueNode* op = v.as()) { - return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data))); + PStatic Reify(const ObjectRef& v, LetList* ll) const { + if (v->IsInstance()) { + auto nd_array = Downcast(v); + return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array))); } else if (const TupleValueNode* op = v.as()) { std::vector fields; tvm::Array fields_dyn; - for (const Value& field : op->fields) { + for (const ObjectRef& field : op->fields) { PStatic ps = Reify(field, ll); fields.push_back(ps); fields_dyn.push_back(ps->dynamic); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 3714425a3323..bd650665e196 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -150,10 +150,8 @@ std::string Executable::Stats() const { // Get the number of constants and the shape of each of them. oss << " Constant shapes (# " << constants.size() << "): ["; for (const auto& it : constants) { - const auto* cell = it.as(); - CHECK(cell); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); + const auto constant = Downcast(it); + const auto& shape = constant.Shape(); // Scalar if (shape.empty()) { @@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) { void Executable::SaveConstantSection(dmlc::Stream* strm) { std::vector arrays; for (const auto& obj : this->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); + const auto cell = Downcast(obj); + arrays.push_back(const_cast(cell.operator->())); } strm->Write(static_cast(this->constants.size())); for (const auto& it : arrays) { @@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; i++) { runtime::NDArray constant; STREAM_CHECK(constant.Load(strm), "constant"); - runtime::ObjectRef obj = runtime::vm::Tensor(constant); - this->constants.push_back(obj); + this->constants.push_back(constant); } } diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index 988ba5d47e7d..d7760d53e4df 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -34,12 +34,6 @@ namespace tvm { namespace runtime { namespace vm { -Tensor::Tensor(NDArray data) { - auto ptr = make_object(); - ptr->data = std::move(data); - data_ = std::move(ptr); -} - Closure::Closure(size_t func_index, std::vector free_vars) { auto ptr = make_object(); ptr->func_index = func_index; @@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector free_vars) { } -TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") -.set_body([](TVMArgs args, TVMRetValue* rv) { - ObjectRef obj = args[0]; - const auto* cell = obj.as(); - CHECK(cell != nullptr); - *rv = cell->data; -}); - TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; @@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("_vmobj.Tensor") -.set_body([](TVMArgs args, TVMRetValue* rv) { -*rv = Tensor(args[0].operator NDArray()); -}); - TVM_REGISTER_GLOBAL("_vmobj.Tuple") .set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; @@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT") *rv = ADT(tag, fields); }); -TVM_REGISTER_OBJECT_TYPE(TensorObj); TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); } // namespace vm diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 10b27d1a0e46..49aba7f65eb1 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { return os; } -ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { - if (const TensorObj* obj = src.as()) { - auto tensor = obj->data; - if (tensor->ctx.device_type != ctx.device_type) { - auto copy = tensor.CopyTo(ctx); - return Tensor(copy); - } else { - return src; +inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { + if (src->IsInstance()) { + auto nd_array = Downcast(src); + if (nd_array->ctx.device_type != ctx.device_type) { + return nd_array.CopyTo(ctx); } - } else { - return src; } + return src; } PackedFunc VirtualMachine::GetFunction(const std::string& name, @@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, if (const auto* dt_cell = args[i].as()) { for (size_t fi = 0; fi < dt_cell->size; ++fi) { auto obj = (*dt_cell)[fi]; - const auto* tensor = obj.as(); - CHECK(tensor != nullptr) << "Expect tensor object, but received: " - << obj->GetTypeKey(); - setter(idx++, tensor->data); + auto nd_array = Downcast(obj); + setter(idx++, nd_array); } } else { - const auto* tensor = args[i].as(); - CHECK(tensor != nullptr) << "Expect tensor object, but received: " - << args[i]->GetTypeKey(); - setter(idx++, tensor->data); + auto nd_array = Downcast(args[i]); + setter(idx++, nd_array); } } @@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const { inline int32_t VirtualMachine::LoadScalarInt(Index r) const { int32_t result; const auto& obj = ReadRegister(r); - const auto* tensor = obj.as(); - CHECK(tensor != nullptr) << "Expect tensor object, but received: " - << obj->GetTypeKey(); - NDArray array = tensor->data.CopyTo({kDLCPU, 0}); + auto nd_array = Downcast(obj); + NDArray array = nd_array.CopyTo({kDLCPU, 0}); if (array->dtype.bits <= 8) { result = reinterpret_cast(array->data)[0]; @@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() { case Opcode::LoadConsti: { auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); reinterpret_cast(tensor->data)[0] = instr.load_consti.val; - WriteRegister(instr.dst, Tensor(tensor)); + WriteRegister(instr.dst, tensor); pc_++; goto main_loop; } @@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() { auto tag = adt.tag(); auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); reinterpret_cast(tag_tensor->data)[0] = tag; - WriteRegister(instr.dst, Tensor(tag_tensor)); + WriteRegister(instr.dst, tag_tensor); pc_++; goto main_loop; } @@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() { auto storage_obj = ReadRegister(instr.alloc_tensor.storage); auto storage = Downcast(storage_obj); - auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); + auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); - auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc_++; goto main_loop; @@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); - const auto* tensor = shape_tensor_obj.as(); - CHECK(tensor != nullptr) << "Expect tensor object, but received: " - << shape_tensor_obj->GetTypeKey(); - NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx); + const auto shape_arr = Downcast(shape_tensor_obj); + NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx); const DLTensor* dl_tensor = shape_tensor.operator->(); CHECK_EQ(dl_tensor->dtype.code, 0u); CHECK_LE(dl_tensor->dtype.bits, 64); @@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() { auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); + auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); - auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc_++; goto main_loop; diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index 612347db1fbd..e39c41e7078f 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -18,6 +18,7 @@ import pytest import tensorflow as tf import numpy as np +from tvm import nd from tvm import relay from tvm.relay.frontend.tensorflow import from_tensorflow @@ -26,7 +27,7 @@ def check_equal(graph, tf_out): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) ex = relay.create_executor('vm', mod=mod) relay_out = ex.evaluate()(**params) - if isinstance(relay_out, relay.vmobj.Tensor): + if isinstance(relay_out, nd.NDArray): np.testing.assert_allclose(tf_out, relay_out.asnumpy()) else: if not isinstance(tf_out, list): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 97557d3c2a04..b3940817be89 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -60,7 +60,7 @@ def convert_to_list(x): } def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.Tensor): + if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] elif isinstance(o, tvm.relay.backend.vmobj.ADT): result = [] @@ -87,8 +87,6 @@ def vmobj_to_list(o): else: raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) - elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): - return [o.data.asnumpy()] else: raise RuntimeError("Unknown object type: %s" % type(o)) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index c0185e438a8d..8e304bd856e9 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -115,10 +115,8 @@ def tree_to_dict(t): def vmobj_to_list(o, dtype="float32"): - if isinstance(o, tvm.relay.backend.vmobj.Tensor): + if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): - return [o.asnumpy()] elif isinstance(o, tvm.relay.backend.vmobj.ADT): if len(o) == 0: tensor_nil = p.get_var("tensor_nil", dtype=dtype) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index c1a19c4d9bb1..85bba4402ea2 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -17,8 +17,9 @@ import numpy as np import tvm import tvm.testing +from tvm import nd from tvm import relay -from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue +from tvm.relay.backend.interpreter import TupleValue from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor @@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): result.asnumpy(), expected_result, rtol=rtol) -def test_from_scalar(): - np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1) - np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0) - np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True) - - def test_tuple_value(): - tv = TupleValue(Value.from_scalar( - 1), Value.from_scalar(2), Value.from_scalar(3)) - np.testing.assert_allclose(tv[0].asnumpy(), 1) - np.testing.assert_allclose(tv[1].asnumpy(), 2) - np.testing.assert_allclose(tv[2].asnumpy(), 3) + tv = TupleValue(relay.const(1), relay.const(2), relay.const(3)) + np.testing.assert_allclose(tv[0].data.asnumpy(), 1) + np.testing.assert_allclose(tv[1].data.asnumpy(), 2) + np.testing.assert_allclose(tv[2].data.asnumpy(), 3) def test_tuple_getitem(): @@ -158,12 +152,6 @@ def test_binds(): tvm.testing.assert_allclose(xx + xx, res) -def test_tensor_value(): - x = relay.var("x", shape=(1, 10)) - xx = np.ones((1, 10)).astype("float32") - check_eval(relay.Function([x], x), [TensorValue(xx)], xx) - - def test_kwargs_params(): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) @@ -174,7 +162,7 @@ def test_kwargs_params(): z_data = np.random.rand(1, 10).astype('float32') params = { 'y': y_data, 'z': z_data } intrp = create_executor("debug") - res = intrp.evaluate(f)(x_data, **params).data + res = intrp.evaluate(f)(x_data, **params) tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) @@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple(): nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil) cons_value = ConstructorValue(prelude.cons.tag, [ - TensorValue(np.random.rand(1, 10).astype('float32')), + nd.array(np.random.rand(1, 10).astype('float32')), nil_value ], prelude.cons) - ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) + ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32'))) tuple_value = TupleValue(*[ - TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10) + nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10) ]) id_func = intrp.evaluate(prelude.id) @@ -236,9 +224,7 @@ def test_tuple_passing(): out = f((10, 8)) tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) # Second use a tuple value. - value_tuple = TupleValue( - TensorValue(np.array(11)), - TensorValue(np.array(12))) + value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12))) out = f(value_tuple) tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) @@ -252,7 +238,6 @@ def test_tuple_passing(): test_binds() test_kwargs_params() test_ref() - test_tensor_value() test_tuple_value() test_tuple_getitem() test_function_taking_adt_ref_tuple() diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 2a07e9514f0c..f87f90a85a0b 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.testing import to_python, run_as_python from tvm.relay.prelude import Prelude -from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue +from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue # helper: uses a dummy let binding to sequence a list # of expressions: expr1; expr2; expr3, etc. @@ -39,9 +39,9 @@ def init_box_adt(mod): return (box, box_ctor) -# assert that the candidate is a TensorValue with value val +# assert that the candidate is a NDArray with value val def assert_tensor_value(candidate, val): - assert isinstance(candidate, TensorValue) + assert isinstance(candidate, tvm.nd.NDArray) assert np.array_equal(candidate.asnumpy(), np.array(val)) @@ -68,6 +68,7 @@ def test_create_empty_tuple(): def test_create_scalar(): scalar = relay.const(1) tensor_val = run_as_python(scalar) + print(type(tensor_val)) assert_tensor_value(tensor_val, 1) @@ -544,7 +545,7 @@ def reference(x, gamma, beta, moving_mean, moving_var): # there will be a change in accuracy so we need to check # approximate equality - assert isinstance(call_val, TensorValue) + assert isinstance(call_val, tvm.nd.NDArray) tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps) verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 8a160b11ee65..d53360d76656 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): return vm.invoke("main", *args) def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vm.Tensor): + if isinstance(o, tvm.nd.NDArray): return [o.asnumpy().tolist()] elif isinstance(o, tvm.relay.backend.vm.ADT): result = [] diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py index 12d263d1125b..82a2b116d82c 100644 --- a/tests/python/relay/test_vm_object.py +++ b/tests/python/relay/test_vm_object.py @@ -19,28 +19,16 @@ import tvm from tvm.relay import vm -def test_tensor(): - arr = tvm.nd.array([1,2,3]) - x = vm.Tensor(arr) - assert isinstance(x, vm.Tensor) - assert x.asnumpy()[0] == 1 - assert x.asnumpy()[-1] == 3 - assert isinstance(x.data, tvm.nd.NDArray) - - def test_adt(): arr = tvm.nd.array([1,2,3]) - x = vm.Tensor(arr) - y = vm.ADT(0, [x, x]) + y = vm.ADT(0, [arr, arr]) assert len(y) == 2 assert isinstance(y, vm.ADT) - y[0:1][-1].data == x.data + y[0:1][-1] == arr assert y.tag == 0 - assert isinstance(x.data, tvm.nd.NDArray) - + assert isinstance(arr, tvm.nd.NDArray) if __name__ == "__main__": - test_tensor() test_adt() From 92d6b832a6b79fbf206ae5696dc8f892701fa7d0 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 8 Jan 2020 21:38:00 +0000 Subject: [PATCH 2/3] NodeBase to Object in Python --- docs/api/python/dev.rst | 13 ++- docs/dev/codebase_walkthrough.rst | 8 +- python/tvm/__init__.py | 4 +- python/tvm/_ffi/_ctypes/function.py | 13 +-- python/tvm/_ffi/_ctypes/object.py | 10 +-- python/tvm/_ffi/_cython/function.pxi | 11 +-- python/tvm/_ffi/_cython/object.pxi | 8 +- python/tvm/_ffi/function.py | 2 +- python/tvm/_ffi/node.py | 89 ------------------- python/tvm/_ffi/object.py | 66 ++++++++++++-- .../{node_generic.py => object_generic.py} | 30 +++---- python/tvm/api.py | 19 ++-- python/tvm/arith.py | 16 ++-- python/tvm/attrs.py | 6 +- python/tvm/build_module.py | 18 ++-- python/tvm/container.py | 28 +++--- python/tvm/expr.py | 84 ++++++++--------- python/tvm/ir_builder.py | 6 +- python/tvm/{node.py => object.py} | 4 +- python/tvm/relay/_module.pyi | 4 +- python/tvm/relay/adt.py | 4 +- python/tvm/relay/backend/compile_engine.py | 10 +-- python/tvm/relay/backend/interpreter.py | 14 +-- python/tvm/relay/base.py | 9 +- python/tvm/relay/expr.pyi | 4 +- python/tvm/relay/quantize/quantize.py | 4 +- python/tvm/relay/transform.pyi | 8 +- python/tvm/relay/ty.pyi | 4 +- python/tvm/schedule.py | 40 ++++----- python/tvm/stmt.py | 32 +++---- python/tvm/target.py | 12 +-- python/tvm/tensor.py | 43 ++++----- python/tvm/tensor_intrin.py | 6 +- .../test_pass_inject_double_buffer.py | 4 +- .../unittest/test_pass_inject_vthread.py | 8 +- .../unittest/test_pass_storage_flatten.py | 4 +- 36 files changed, 298 insertions(+), 347 deletions(-) delete mode 100644 python/tvm/_ffi/node.py rename python/tvm/_ffi/{node_generic.py => object_generic.py} (82%) rename python/tvm/{node.py => object.py} (93%) diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 7bb938ca7517..8a0a70588bc3 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -20,17 +20,14 @@ Developer API This page contains modules that are used by developers of TVM. Many of these APIs are PackedFunc registered in C++ backend. -tvm.node -~~~~~~~~ -.. automodule:: tvm.node - -.. autoclass:: tvm.node.NodeBase - :members: +tvm.object +~~~~~~~~~~ +.. automodule:: tvm.object -.. autoclass:: tvm.node.Node +.. autoclass:: tvm.object.Object :members: -.. autofunction:: tvm.register_node +.. autofunction:: tvm.register_object tvm.expr ~~~~~~~~ diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 19f185edca98..0732c26f0c58 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") -Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. +Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``. :: - @register_node - class Tensor(NodeBase, _expr.ExprOp): + @register_object + class Tensor(Object, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): ... -The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document `_, and details are in ``python/tvm/_ffi/`` if you are interested. +The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document `_, and details are in ``python/tvm/_ffi/`` if you are interested. ``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``: diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 9e3eb0faefb8..b2a4ca3ccf13 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -34,7 +34,7 @@ from . import container from . import schedule from . import module -from . import node +from . import object from . import attrs from . import ir_builder from . import target @@ -55,7 +55,7 @@ from .api import * from .intrin import * from .tensor_intrin import decl_tensor_intrin -from .node import register_node +from .object import register_object from .ndarray import register_extension from .schedule import create_schedule from .build_module import build, lower, build_config diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 2f0b5babda4d..45048c5768a9 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -25,14 +25,14 @@ from ..base import _LIB, get_last_ffi_error, py2cerror from ..base import c_str, string_types -from ..node_generic import convert_to_node, NodeGeneric +from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .object import ObjectBase, _set_class_node +from .object import ObjectBase, _set_class_object from . import object as _object FunctionHandle = ctypes.c_void_p @@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, string_types): values[i].v_str = c_str(arg) type_codes[i] = TypeCode.STR - elif isinstance(arg, (list, tuple, dict, NodeGeneric)): - arg = convert_to_node(arg) + elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): + arg = convert_to_object(arg) values[i].v_handle = arg.handle type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) @@ -256,7 +256,6 @@ def _handle_return_func(x): _CLASS_MODULE = None _CLASS_FUNCTION = None -_CLASS_OBJECT = None def _set_class_module(module_class): """Initialize the module.""" @@ -266,7 +265,3 @@ def _set_class_module(module_class): def _set_class_function(func_class): global _CLASS_FUNCTION _CLASS_FUNCTION = func_class - -def _set_class_object(obj_class): - global _CLASS_OBJECT - _CLASS_OBJECT = obj_class diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index b8b8aefea131..8a2fb1b5363e 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -30,11 +30,11 @@ """Maps object type to its constructor""" OBJECT_TYPE = {} -_CLASS_NODE = None +_CLASS_OBJECT = None -def _set_class_node(node_class): - global _CLASS_NODE - _CLASS_NODE = node_class +def _set_class_object(object_class): + global _CLASS_OBJECT + _CLASS_OBJECT = object_class def _register_object(index, cls): @@ -51,7 +51,7 @@ def _return_object(x): handle = ObjectHandle(handle) tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) - cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index a2360427b6c7..7789769a3901 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -20,7 +20,7 @@ import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types, py2cerror -from ..node_generic import convert_to_node, NodeGeneric +from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray @@ -149,8 +149,8 @@ cdef inline int make_arg(object arg, value[0].v_str = tstr tcode[0] = kStr temp_args.append(tstr) - elif isinstance(arg, (list, tuple, dict, NodeGeneric)): - arg = convert_to_node(arg) + elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): + arg = convert_to_object(arg) value[0].v_handle = (arg).chandle tcode[0] = kObjectHandle temp_args.append(arg) @@ -308,7 +308,6 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None -_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -322,7 +321,3 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class - -def _set_class_node(node_class): - global _CLASS_NODE - _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 6d20723fd188..1392f9944835 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -32,7 +32,7 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE - global _CLASS_NODE + global _CLASS_OBJECT cdef unsigned tindex cdef object cls cdef object handle @@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle): if cls is not None: obj = cls.__new__(cls) else: - # default use node base class - # TODO(tqchen) change to object after Node unifies with Object - obj = _CLASS_NODE.__new__(_CLASS_NODE) + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) else: - obj = _CLASS_NODE.__new__(_CLASS_NODE) + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle return obj diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 23d95ebbf66b..22e03563976b 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,7 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from .node_generic import _set_class_objects +from .object_generic import _set_class_objects IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py deleted file mode 100644 index c6c151af9053..000000000000 --- a/python/tvm/_ffi/node.py +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Node namespace""" -# pylint: disable=unused-import -from __future__ import absolute_import - -import ctypes -import sys -from .. import _api_internal -from .object import Object, register_object, _set_class_node -from .node_generic import NodeGeneric, convert_to_node, const - - -def _new_object(cls): - """Helper function for pickle""" - return cls.__new__(cls) - - -class NodeBase(Object): - """NodeBase is the base class of all TVM language AST object.""" - def __repr__(self): - return _api_internal._format_str(self) - - def __dir__(self): - fnames = _api_internal._NodeListAttrNames(self) - size = fnames(-1) - return [fnames(i) for i in range(size)] - - def __getattr__(self, name): - try: - return _api_internal._NodeGetAttr(self, name) - except AttributeError: - raise AttributeError( - "%s has no attribute %s" % (str(type(self)), name)) - - def __hash__(self): - return _api_internal._raw_ptr(self) - - def __eq__(self, other): - return self.same_as(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __reduce__(self): - cls = type(self) - return (_new_object, (cls, ), self.__getstate__()) - - def __getstate__(self): - handle = self.handle - if handle is not None: - return {'handle': _api_internal._save_json(self)} - return {'handle': None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot - handle = state['handle'] - if handle is not None: - json_str = handle - other = _api_internal._load_json(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None - - def same_as(self, other): - """check object identity equality""" - if not isinstance(other, NodeBase): - return False - return self.__hash__() == other.__hash__() - - -# pylint: disable=invalid-name -register_node = register_object -_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index 002fd27af0fd..83d4129a7140 100644 --- a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, unused-import """Runtime Object API""" from __future__ import absolute_import import sys import ctypes +from .. import _api_internal from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from .object_generic import ObjectGeneric, convert_to_object, const IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError @@ -29,23 +31,77 @@ if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object, _set_class_node + from ._cy3.core import _set_class_object from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object else: - from ._cy2.core import _set_class_object, _set_class_node + from ._cy2.core import _set_class_object from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object except IMPORT_EXCEPT: # pylint: disable=wrong-import-position,unused-import - from ._ctypes.function import _set_class_object, _set_class_node + from ._ctypes.function import _set_class_object from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object +def _new_object(cls): + """Helper function for pickle""" + return cls.__new__(cls) + + class Object(_ObjectBase): """Base class for all tvm's runtime objects.""" - pass + def __repr__(self): + return _api_internal._format_str(self) + + def __dir__(self): + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) + + def __hash__(self): + return _api_internal._raw_ptr(self) + + def __eq__(self, other): + return self.same_as(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __reduce__(self): + cls = type(self) + return (_new_object, (cls, ), self.__getstate__()) + + def __getstate__(self): + handle = self.handle + if handle is not None: + return {'handle': _api_internal._save_json(self)} + return {'handle': None} + + def __setstate__(self, state): + # pylint: disable=assigning-non-slot + handle = state['handle'] + if handle is not None: + json_str = handle + other = _api_internal._load_json(json_str) + self.handle = other.handle + other.handle = None + else: + self.handle = None + + def same_as(self, other): + """check object identity equality""" + if not isinstance(other, Object): + return False + return self.__hash__() == other.__hash__() def register_object(type_key=None): diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/object_generic.py similarity index 82% rename from python/tvm/_ffi/node_generic.py rename to python/tvm/_ffi/object_generic.py index 8ee7fc5f2b5b..92e73ad79e88 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/object_generic.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Common implementation of Node generic related logic""" +"""Common implementation of object generic related logic""" # pylint: disable=unused-import from __future__ import absolute_import @@ -22,7 +22,7 @@ from .. import _api_internal from .base import string_types -# Node base class +# Object base class _CLASS_OBJECTS = None def _set_class_objects(cls): @@ -47,15 +47,15 @@ def _scalar_type_inference(value): return dtype -class NodeGeneric(object): - """Base class for all classes that can be converted to node.""" - def asnode(self): - """Convert value to node""" +class ObjectGeneric(object): + """Base class for all classes that can be converted to object.""" + def asobject(self): + """Convert value to object""" raise NotImplementedError() -def convert_to_node(value): - """Convert a python value to corresponding node type. +def convert_to_object(value): + """Convert a python value to corresponding object type. Parameters ---------- @@ -64,8 +64,8 @@ def convert_to_node(value): Returns ------- - node : Node - The corresponding node value. + obj : Object + The corresponding object value. """ if isinstance(value, _CLASS_OBJECTS): return value @@ -76,7 +76,7 @@ def convert_to_node(value): if isinstance(value, string_types): return _api_internal._str(value) if isinstance(value, (list, tuple)): - value = [convert_to_node(x) for x in value] + value = [convert_to_object(x) for x in value] return _api_internal._Array(*value) if isinstance(value, dict): vlist = [] @@ -85,14 +85,14 @@ def convert_to_node(value): not isinstance(item[0], string_types)): raise ValueError("key of map must already been a container type") vlist.append(item[0]) - vlist.append(convert_to_node(item[1])) + vlist.append(convert_to_object(item[1])) return _api_internal._Map(*vlist) - if isinstance(value, NodeGeneric): - return value.asnode() + if isinstance(value, ObjectGeneric): + return value.asobject() if value is None: return None - raise ValueError("don't know how to convert type %s to node" % type(value)) + raise ValueError("don't know how to convert type %s to object" % type(value)) def const(value, dtype=None): diff --git a/python/tvm/api.py b/python/tvm/api.py index 4d0e3472683c..7395d3524709 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -22,9 +22,8 @@ from ._ffi.base import string_types from ._ffi.object import register_object, Object -from ._ffi.node import register_node, NodeBase -from ._ffi.node import convert_to_node as _convert_to_node -from ._ffi.node_generic import _scalar_type_inference +from ._ffi.object import convert_to_object as _convert_to_object +from ._ffi.object_generic import _scalar_type_inference from ._ffi.function import Function from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import convert_to_tvm_func as _convert_tvm_func @@ -111,7 +110,7 @@ def get_env_func(name): Note ---- - EnvFunc is a Node wrapper around + EnvFunc is a Object wrapper around global function that can be serialized via its name. This can be used to serialize function field in the language. """ @@ -127,16 +126,16 @@ def convert(value): Returns ------- - tvm_val : Node or Function + tvm_val : Object or Function Converted value in TVM """ - if isinstance(value, (Function, NodeBase)): + if isinstance(value, (Function, Object)): return value if callable(value): return _convert_tvm_func(value) - return _convert_to_node(value) + return _convert_to_object(value) def load_json(json_str): @@ -149,7 +148,7 @@ def load_json(json_str): Returns ------- - node : Node + node : Object The loaded tvm node. """ return _api_internal._load_json(json_str) @@ -160,8 +159,8 @@ def save_json(node): Parameters ---------- - node : Node - A TVM Node object to be saved. + node : Object + A TVM object to be saved. Returns ------- diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 4c3c05f75796..81f478c66b92 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -17,11 +17,11 @@ """Arithmetic data structure and utility""" from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from ._ffi.function import _init_api from . import _api_internal -class IntSet(NodeBase): +class IntSet(Object): """Represent a set of integer in one dimension.""" def is_nothing(self): """Whether the set represent nothing""" @@ -32,7 +32,7 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_node("arith.IntervalSet") +@register_object("arith.IntervalSet") class IntervalSet(IntSet): """Represent set of continuous interval [min_value, max_value] @@ -49,16 +49,16 @@ def __init__(self, min_value, max_value): _make_IntervalSet, min_value, max_value) -@register_node("arith.ModularSet") -class ModularSet(NodeBase): +@register_object("arith.ModularSet") +class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z """ def __init__(self, coeff, base): self.__init_handle_by_constructor__( _make_ModularSet, coeff, base) -@register_node("arith.ConstIntBound") -class ConstIntBound(NodeBase): +@register_object("arith.ConstIntBound") +class ConstIntBound(Object): """Represent constant integer bound Parameters @@ -245,7 +245,7 @@ def update(self, var, info, override=False): var : tvm.Var The variable. - info : tvm.NodeBase + info : tvm.Object Related information. override : bool diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py index e2a27328fdcc..2963a0e21734 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/attrs.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """ TVM Attribute module, which is mainly used for defining attributes of operators""" -from ._ffi.node import NodeBase, register_node as _register_tvm_node +from ._ffi.object import Object, register_object from ._ffi.function import _init_api from . import _api_internal -@_register_tvm_node -class Attrs(NodeBase): +@register_object +class Attrs(Object): """Attribute node, which is mainly use for defining attributes of relay operators. Used by function registered in python side, such as compute, schedule and alter_layout. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index f96e28323595..85d2b8514779 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -23,7 +23,7 @@ import warnings from ._ffi.function import Function -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import api from . import _api_internal from . import tensor @@ -115,22 +115,22 @@ def exit(self): DumpIR.scope_level -= 1 -@register_node -class BuildConfig(NodeBase): +@register_object +class BuildConfig(Object): """Configuration scope to set a build config option. Note ---- - This object is backed by node system in C++, with arguments that can be + This object is backed by object protocol in C++, with arguments that can be exchanged between python and C++. Do not construct directly, use build_config instead. - The fields that are backed by the C++ node are immutable once an instance - is constructed. See _node_defaults for the fields. + The fields that are backed by the C++ object are immutable once an instance + is constructed. See _object_defaults for the fields. """ - _node_defaults = { + _object_defaults = { "auto_unroll_max_step": 0, "auto_unroll_max_depth": 8, "auto_unroll_max_extent": 0, @@ -191,7 +191,7 @@ def __exit__(self, ptype, value, trace): _api_internal._ExitBuildConfigScope(self) def __setattr__(self, name, value): - if name in BuildConfig._node_defaults: + if name in BuildConfig._object_defaults: raise AttributeError( "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) return super(BuildConfig, self).__setattr__(name, value) @@ -257,7 +257,7 @@ def build_config(**kwargs): The build configuration """ node_args = {k: v if k not in kwargs else kwargs[k] - for k, v in BuildConfig._node_defaults.items()} + for k, v in BuildConfig._object_defaults.items()} config = make.node("BuildConfig", **node_args) if "add_lower_pass" in kwargs: diff --git a/python/tvm/container.py b/python/tvm/container.py index aedbe95b01b2..274fc1f4027c 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -16,11 +16,11 @@ # under the License. """Container data structures used in TVM DSL.""" from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import _api_internal -@register_node -class Array(NodeBase): +@register_object +class Array(Object): """Array container of TVM. You do not need to create Array explicitly. @@ -50,8 +50,8 @@ def __len__(self): return _api_internal._ArraySize(self) -@register_node -class EnvFunc(NodeBase): +@register_object +class EnvFunc(Object): """Environment function. This is a global function object that can be serialized by its name. @@ -64,13 +64,13 @@ def func(self): return _api_internal._EnvFuncGetPackedFunc(self) -@register_node -class Map(NodeBase): +@register_object +class Map(Object): """Map container of TVM. You do not need to create Map explicitly. Normally python dict will be converted automaticall to Map during tvm function call. - You can use convert to create a dict[NodeBase-> NodeBase] into a Map + You can use convert to create a dict[Object-> Object] into a Map """ def __getitem__(self, k): return _api_internal._MapGetItem(self, k) @@ -87,11 +87,11 @@ def __len__(self): return _api_internal._MapSize(self) -@register_node +@register_object class StrMap(Map): """A special map container that has str as key. - You can use convert to create a dict[str->NodeBase] into a Map. + You can use convert to create a dict[str->Object] into a Map. """ def items(self): """Get the items from the map""" @@ -99,8 +99,8 @@ def items(self): return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] -@register_node -class Range(NodeBase): +@register_object +class Range(Object): """Represent a range in TVM. You do not need to create a Range explicitly. @@ -108,8 +108,8 @@ class Range(NodeBase): """ -@register_node -class LoweredFunc(NodeBase): +@register_object +class LoweredFunc(Object): """Represent a LoweredFunc in TVM.""" MixedFunc = 0 HostFunc = 1 diff --git a/python/tvm/expr.py b/python/tvm/expr.py index c6b3d9b866e2..d147dd622fd3 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -32,7 +32,7 @@ """ # pylint: disable=missing-docstring from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, NodeGeneric, register_node +from ._ffi.object import Object, register_object, ObjectGeneric from ._ffi.runtime_ctypes import TVMType, TypeCode from . import make as _make from . import generic as _generic @@ -178,11 +178,11 @@ def astype(self, dtype): return _generic.cast(self, dtype) -class EqualOp(NodeGeneric, ExprOp): +class EqualOp(ObjectGeneric, ExprOp): """Deferred equal operator. This is used to support sugar that a == b can either - mean NodeBase.same_as or NodeBase.equal. + mean Object.same_as or Object.equal. Parameters ---------- @@ -205,16 +205,16 @@ def __nonzero__(self): def __bool__(self): return self.__nonzero__() - def asnode(self): - """Convert node.""" + def asobject(self): + """Convert object.""" return _make._OpEQ(self.a, self.b) -class NotEqualOp(NodeGeneric, ExprOp): +class NotEqualOp(ObjectGeneric, ExprOp): """Deferred NE operator. This is used to support sugar that a != b can either - mean not NodeBase.same_as or make.NE. + mean not Object.same_as or make.NE. Parameters ---------- @@ -237,8 +237,8 @@ def __nonzero__(self): def __bool__(self): return self.__nonzero__() - def asnode(self): - """Convert node.""" + def asobject(self): + """Convert object.""" return _make._OpNE(self.a, self.b) @@ -246,7 +246,7 @@ class PrimExpr(ExprOp, NodeBase): """Base class of all tvm Expressions""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ - __hash__ = NodeBase.__hash__ + __hash__ = Object.__hash__ class ConstExpr(PrimExpr): @@ -261,7 +261,7 @@ class CmpExpr(PrimExpr): class LogicalExpr(PrimExpr): pass -@register_node("Variable") +@register_object("Variable") class Var(PrimExpr): """Symbolic variable. @@ -278,7 +278,7 @@ def __init__(self, name, dtype): _api_internal._Var, name, dtype) -@register_node +@register_object class Reduce(PrimExpr): """Reduce node. @@ -305,7 +305,7 @@ def __init__(self, combiner, src, rdom, condition, value_index): condition, value_index) -@register_node +@register_object class FloatImm(ConstExpr): """Float constant. @@ -321,7 +321,7 @@ def __init__(self, dtype, value): self.__init_handle_by_constructor__( _make.FloatImm, dtype, value) -@register_node +@register_object class IntImm(ConstExpr): """Int constant. @@ -341,7 +341,7 @@ def __int__(self): return self.value -@register_node +@register_object class UIntImm(ConstExpr): """UInt constant. @@ -358,7 +358,7 @@ def __init__(self, dtype, value): _make.UIntImm, dtype, value) -@register_node +@register_object class StringImm(ConstExpr): """String constant. @@ -382,7 +382,7 @@ def __ne__(self, other): return self.value != other -@register_node +@register_object class Cast(PrimExpr): """Cast expression. @@ -399,7 +399,7 @@ def __init__(self, dtype, value): _make.Cast, dtype, value) -@register_node +@register_object class Add(BinaryOpExpr): """Add node. @@ -416,7 +416,7 @@ def __init__(self, a, b): _make.Add, a, b) -@register_node +@register_object class Sub(BinaryOpExpr): """Sub node. @@ -433,7 +433,7 @@ def __init__(self, a, b): _make.Sub, a, b) -@register_node +@register_object class Mul(BinaryOpExpr): """Mul node. @@ -450,7 +450,7 @@ def __init__(self, a, b): _make.Mul, a, b) -@register_node +@register_object class Div(BinaryOpExpr): """Div node. @@ -467,7 +467,7 @@ def __init__(self, a, b): _make.Div, a, b) -@register_node +@register_object class Mod(BinaryOpExpr): """Mod node. @@ -484,7 +484,7 @@ def __init__(self, a, b): _make.Mod, a, b) -@register_node +@register_object class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -501,7 +501,7 @@ def __init__(self, a, b): _make.FloorDiv, a, b) -@register_node +@register_object class FloorMod(BinaryOpExpr): """FloorMod node. @@ -518,7 +518,7 @@ def __init__(self, a, b): _make.FloorMod, a, b) -@register_node +@register_object class Min(BinaryOpExpr): """Min node. @@ -535,7 +535,7 @@ def __init__(self, a, b): _make.Min, a, b) -@register_node +@register_object class Max(BinaryOpExpr): """Max node. @@ -552,7 +552,7 @@ def __init__(self, a, b): _make.Max, a, b) -@register_node +@register_object class EQ(CmpExpr): """EQ node. @@ -569,7 +569,7 @@ def __init__(self, a, b): _make.EQ, a, b) -@register_node +@register_object class NE(CmpExpr): """NE node. @@ -586,7 +586,7 @@ def __init__(self, a, b): _make.NE, a, b) -@register_node +@register_object class LT(CmpExpr): """LT node. @@ -603,7 +603,7 @@ def __init__(self, a, b): _make.LT, a, b) -@register_node +@register_object class LE(CmpExpr): """LE node. @@ -620,7 +620,7 @@ def __init__(self, a, b): _make.LE, a, b) -@register_node +@register_object class GT(CmpExpr): """GT node. @@ -637,7 +637,7 @@ def __init__(self, a, b): _make.GT, a, b) -@register_node +@register_object class GE(CmpExpr): """GE node. @@ -654,7 +654,7 @@ def __init__(self, a, b): _make.GE, a, b) -@register_node +@register_object class And(LogicalExpr): """And node. @@ -671,7 +671,7 @@ def __init__(self, a, b): _make.And, a, b) -@register_node +@register_object class Or(LogicalExpr): """Or node. @@ -688,7 +688,7 @@ def __init__(self, a, b): _make.Or, a, b) -@register_node +@register_object class Not(LogicalExpr): """Not node. @@ -702,7 +702,7 @@ def __init__(self, a): _make.Not, a) -@register_node +@register_object class Select(PrimExpr): """Select node. @@ -730,7 +730,7 @@ def __init__(self, condition, true_value, false_value): _make.Select, condition, true_value, false_value) -@register_node +@register_object class Load(PrimExpr): """Load node. @@ -753,7 +753,7 @@ def __init__(self, dtype, buffer_var, index, predicate): _make.Load, dtype, buffer_var, index, predicate) -@register_node +@register_object class Ramp(PrimExpr): """Ramp node. @@ -773,7 +773,7 @@ def __init__(self, base, stride, lanes): _make.Ramp, base, stride, lanes) -@register_node +@register_object class Broadcast(PrimExpr): """Broadcast node. @@ -790,7 +790,7 @@ def __init__(self, value, lanes): _make.Broadcast, value, lanes) -@register_node +@register_object class Shuffle(PrimExpr): """Shuffle node. @@ -807,7 +807,7 @@ def __init__(self, vectors, indices): _make.Shuffle, vectors, indices) -@register_node +@register_object class Call(PrimExpr): """Call node. @@ -842,7 +842,7 @@ def __init__(self, dtype, name, args, call_type, func, value_index): _make.Call, dtype, name, args, call_type, func, value_index) -@register_node +@register_object class Let(PrimExpr): """Let node. diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index bf41c98a7bdd..ede17a154285 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -24,7 +24,7 @@ from . import ir_pass as _pass from . import container as _container from ._ffi.base import string_types -from ._ffi.node import NodeGeneric +from ._ffi.object import ObjectGeneric from ._ffi.runtime_ctypes import TVMType from .expr import Call as _Call @@ -41,7 +41,7 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class BufferVar(NodeGeneric): +class BufferVar(ObjectGeneric): """Buffer variable with content type, makes load store easily. Do not create it directly, create use IRBuilder. @@ -70,7 +70,7 @@ def __init__(self, builder, buffer_var, content_type): self._buffer_var = buffer_var self._content_type = content_type - def asnode(self): + def asobject(self): return self._buffer_var @property diff --git a/python/tvm/node.py b/python/tvm/object.py similarity index 93% rename from python/tvm/node.py rename to python/tvm/object.py index 1d5b506fabe7..9659d3c89067 100644 --- a/python/tvm/node.py +++ b/python/tvm/object.py @@ -20,6 +20,4 @@ """ # pylint: disable=unused-import from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node - -Node = NodeBase +from ._ffi.object import Object, register_object diff --git a/python/tvm/relay/_module.pyi b/python/tvm/relay/_module.pyi index ae2d199de257..66c994e4400e 100644 --- a/python/tvm/relay/_module.pyi +++ b/python/tvm/relay/_module.pyi @@ -16,7 +16,7 @@ # under the License. from typing import Union, Tuple, Dict, List -from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId +from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId from relay.ir import ShapeExtension, Operator, Defn -class Module(NodeBase): ... +class Module(Object): ... diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 30db22cf8314..7f7496b1a407 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """Algebraic data types in Relay.""" -from .base import RelayNode, register_relay_node, NodeBase +from .base import RelayNode, register_relay_node, Object from . import _make from .ty import Type from .expr import Expr, Call @@ -184,7 +184,7 @@ def __init__(self, header, type_vars, constructors): @register_relay_node -class Clause(NodeBase): +class Clause(Object): """Clause for pattern matching in Relay.""" def __init__(self, lhs, rhs): diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 6c690a9b71de..956ad55404bf 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -17,19 +17,19 @@ """Backend code generation engine.""" from __future__ import absolute_import -from ..base import register_relay_node, NodeBase +from ..base import register_relay_node, Object from ... import target as _target from .. import expr as _expr from . import _backend @register_relay_node -class CachedFunc(NodeBase): +class CachedFunc(Object): """Low-level tensor function to back a relay primitive function. """ @register_relay_node -class CCacheKey(NodeBase): +class CCacheKey(Object): """Key in the CompileEngine. Parameters @@ -46,7 +46,7 @@ def __init__(self, source_func, target): @register_relay_node -class CCacheValue(NodeBase): +class CCacheValue(Object): """Value in the CompileEngine, including usage statistics. """ @@ -64,7 +64,7 @@ def _get_cache_key(source_func, target): @register_relay_node -class CompileEngine(NodeBase): +class CompileEngine(Object): """CompileEngine to get lowered code. """ def __init__(self): diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 128edfca0fe1..59d9a8fae43c 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -24,12 +24,12 @@ from .. import _make, analysis, transform from .. import module from ... import nd -from ..base import NodeBase, register_relay_node +from ..base import Object, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder @register_relay_node -class TupleValue(NodeBase): +class TupleValue(Object): """A tuple value produced by the interpreter.""" def __init__(self, *fields): self.__init_handle_by_constructor__( @@ -54,24 +54,24 @@ def __iter__(self): @register_relay_node -class Closure(NodeBase): +class Closure(Object): """A closure produced by the interpreter.""" @register_relay_node -class RecClosure(NodeBase): +class RecClosure(Object): """A recursive closure produced by the interpreter.""" @register_relay_node -class ConstructorValue(NodeBase): +class ConstructorValue(Object): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) @register_relay_node -class RefValue(NodeBase): +class RefValue(Object): def __init__(self, value): self.__init_handle_by_constructor__( _make.RefValue, value) @@ -189,7 +189,7 @@ def evaluate(self, expr=None, binds=None): Returns ------- - val : Union[function, NodeBase] + val : Union[function, Object] The evaluation result. """ if binds: diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index eb604a405410..d389803bfeea 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -17,12 +17,13 @@ # pylint: disable=no-else-return, unidiomatic-typecheck """The base node types for the Relay language.""" from __future__ import absolute_import as _abs -from .._ffi.node import NodeBase, register_node as _register_tvm_node +from .._ffi.object import register_object as _register_tvm_node +from .._ffi.object import Object from . import _make from . import _expr from . import _base -NodeBase = NodeBase +Object = Object def register_relay_node(type_key=None): """Register a Relay node type. @@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None): return _register_tvm_node(type_key) -class RelayNode(NodeBase): +class RelayNode(Object): """Base class of all Relay nodes.""" def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. @@ -102,7 +103,7 @@ def __init__(self, name): self.__init_handle_by_constructor__(_make.SourceName, name) @register_relay_node -class Id(NodeBase): +class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. """ diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index d264e99e0577..d2d01720f5ff 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -17,12 +17,12 @@ from typing import List import tvm -from .base import Span, NodeBase +from .base import Span, Object from .ty import Type, TypeParam from ._analysis import _get_checked_type -class Expr(NodeBase): +class Expr(Object): def checked_type(self): ... diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index ac5387cf2512..a9d877cecd51 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from .. import expr as _expr from .. import transform as _transform from ... import make as _make -from ..base import NodeBase, register_relay_node +from ..base import Object, register_relay_node class QAnnotateKind(object): @@ -53,7 +53,7 @@ def _forward_op(ref_call, args): @register_relay_node("relay.quantize.QConfig") -class QConfig(NodeBase): +class QConfig(Object): """Configure the quantization behavior by setting config variables. Note diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi index 343e89976b09..2c466b0576a7 100644 --- a/python/tvm/relay/transform.pyi +++ b/python/tvm/relay/transform.pyi @@ -16,14 +16,14 @@ # under the License. import tvm -from .base import NodeBase +from .base import Object -class PassContext(NodeBase): +class PassContext(Object): def __init__(self): ... -class PassInfo(NodeBase): +class PassInfo(Object): name = ... # type: str opt_level = ... # type: int required = ... # type: list @@ -32,7 +32,7 @@ class PassInfo(NodeBase): # type: (str, int, list) -> None -class Pass(NodeBase): +class Pass(Object): def __init__(self): ... diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi index 5a7ecffb372e..cde851160167 100644 --- a/python/tvm/relay/ty.pyi +++ b/python/tvm/relay/ty.pyi @@ -18,11 +18,11 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The type nodes of the Relay language.""" from enum import IntEnum -from .base import NodeBase, register_relay_node +from .base import Object, register_relay_node from . import _make -class Type(NodeBase): +class Type(Object): """The base type for all Relay types.""" def __eq__(self, other): diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 6b577c456fac..c8fcd7cbd52d 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -17,8 +17,8 @@ """The computation schedule api of TVM.""" from __future__ import absolute_import as _abs from ._ffi.base import string_types -from ._ffi.node import NodeBase, register_node -from ._ffi.node import convert_to_node as _convert_to_node +from ._ffi.object import Object, register_object +from ._ffi.object import convert_to_object as _convert_to_object from ._ffi.function import _init_api, Function from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from . import _api_internal @@ -27,7 +27,7 @@ from . import container as _container def convert(value): - """Convert value to TVM node or function. + """Convert value to TVM object or function. Parameters ---------- @@ -35,19 +35,19 @@ def convert(value): Returns ------- - tvm_val : Node or Function + tvm_val : Object or Function Converted value in TVM """ - if isinstance(value, (Function, NodeBase)): + if isinstance(value, (Function, Object)): return value if callable(value): return _convert_tvm_func(value) - return _convert_to_node(value) + return _convert_to_object(value) -@register_node -class Buffer(NodeBase): +@register_object +class Buffer(Object): """Symbolic data buffer in TVM. Buffer provide a way to represent data layout @@ -156,23 +156,23 @@ def vstore(self, begin, value): return _api_internal._BufferVStore(self, begin, value) -@register_node -class Split(NodeBase): +@register_object +class Split(Object): """Split operation on axis.""" -@register_node -class Fuse(NodeBase): +@register_object +class Fuse(Object): """Fuse operation on axis.""" -@register_node -class Singleton(NodeBase): +@register_object +class Singleton(Object): """Singleton axis.""" -@register_node -class IterVar(NodeBase, _expr.ExprOp): +@register_object +class IterVar(Object, _expr.ExprOp): """Represent iteration variable. IterVar is normally created by Operation, to represent @@ -214,8 +214,8 @@ def create_schedule(ops): return _api_internal._CreateSchedule(ops) -@register_node -class Schedule(NodeBase): +@register_object +class Schedule(Object): """Schedule for all the stages.""" def __getitem__(self, k): if isinstance(k, _tensor.Tensor): @@ -348,8 +348,8 @@ def rfactor(self, tensor, axis, factor_axis=0): return factored[0] if len(factored) == 1 else factored -@register_node -class Stage(NodeBase): +@register_object +class Stage(Object): """A Stage represents schedule for one operation.""" def split(self, parent, factor=None, nparts=None): """Split the stage either by factor providing outer scope, or both diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 64628d1d4198..6b87fcb1b885 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -30,14 +30,14 @@ assert(st.buffer_var == a) """ from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import make as _make -class Stmt(NodeBase): +class Stmt(Object): pass -@register_node +@register_object class LetStmt(Stmt): """LetStmt node. @@ -57,7 +57,7 @@ def __init__(self, var, value, body): _make.LetStmt, var, value, body) -@register_node +@register_object class AssertStmt(Stmt): """AssertStmt node. @@ -77,7 +77,7 @@ def __init__(self, condition, message, body): _make.AssertStmt, condition, message, body) -@register_node +@register_object class ProducerConsumer(Stmt): """ProducerConsumer node. @@ -97,7 +97,7 @@ def __init__(self, func, is_producer, body): _make.ProducerConsumer, func, is_producer, body) -@register_node +@register_object class For(Stmt): """For node. @@ -137,7 +137,7 @@ def __init__(self, for_type, device_api, body) -@register_node +@register_object class Store(Stmt): """Store node. @@ -160,7 +160,7 @@ def __init__(self, buffer_var, value, index, predicate): _make.Store, buffer_var, value, index, predicate) -@register_node +@register_object class Provide(Stmt): """Provide node. @@ -183,7 +183,7 @@ def __init__(self, func, value_index, value, args): _make.Provide, func, value_index, value, args) -@register_node +@register_object class Allocate(Stmt): """Allocate node. @@ -215,7 +215,7 @@ def __init__(self, extents, condition, body) -@register_node +@register_object class AttrStmt(Stmt): """AttrStmt node. @@ -238,7 +238,7 @@ def __init__(self, node, attr_key, value, body): _make.AttrStmt, node, attr_key, value, body) -@register_node +@register_object class Free(Stmt): """Free node. @@ -252,7 +252,7 @@ def __init__(self, buffer_var): _make.Free, buffer_var) -@register_node +@register_object class Realize(Stmt): """Realize node. @@ -288,7 +288,7 @@ def __init__(self, bounds, condition, body) -@register_node +@register_object class SeqStmt(Stmt): """Sequence of statements. @@ -308,7 +308,7 @@ def __len__(self): return len(self.seq) -@register_node +@register_object class IfThenElse(Stmt): """IfThenElse node. @@ -328,7 +328,7 @@ def __init__(self, condition, then_case, else_case): _make.IfThenElse, condition, then_case, else_case) -@register_node +@register_object class Evaluate(Stmt): """Evaluate node. @@ -342,7 +342,7 @@ def __init__(self, value): _make.Evaluate, value) -@register_node +@register_object class Prefetch(Stmt): """Prefetch node. diff --git a/python/tvm/target.py b/python/tvm/target.py index afddd5f1fd59..c2d37529040b 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -59,7 +59,7 @@ import warnings from ._ffi.base import _LIB_NAME -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import _api_internal try: @@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts): return opts -@register_node -class Target(NodeBase): +@register_object +class Target(Object): """Target device information, use through TVM API. Note @@ -97,7 +97,7 @@ class Target(NodeBase): """ def __new__(cls): # Always override new to enable class - obj = NodeBase.__new__(cls) + obj = Object.__new__(cls) obj._keys = None obj._options = None obj._libs = None @@ -146,8 +146,8 @@ def __exit__(self, ptype, value, trace): _api_internal._ExitTargetScope(self) -@register_node -class GenericFunc(NodeBase): +@register_object +class GenericFunc(Object): """GenericFunc node reference. This represents a generic function that may be specialized for different targets. When this object is called, a specialization is chosen based on the current target. diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index e4a2f4f76e7b..e4c36c11120b 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -17,13 +17,14 @@ """Tensor and Operation class for computation declaration.""" # pylint: disable=invalid-name from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node +from ._ffi.object import Object, register_object, ObjectGeneric, \ + convert_to_object from . import _api_internal from . import make as _make from . import expr as _expr -class TensorSlice(NodeGeneric, _expr.ExprOp): +class TensorSlice(ObjectGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" def __init__(self, tensor, indices): @@ -37,8 +38,8 @@ def __getitem__(self, indices): indices = (indices,) return TensorSlice(self.tensor, self.indices + indices) - def asnode(self): - """Convert slice to node.""" + def asobject(self): + """Convert slice to object.""" return self.tensor(*self.indices) @property @@ -46,23 +47,23 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype -@register_node -class TensorIntrinCall(NodeBase): +@register_object +class TensorIntrinCall(Object): """Intermediate structure for calling a tensor intrinsic.""" itervar_cls = None -@register_node -class Tensor(NodeBase, _expr.ExprOp): +@register_object +class Tensor(Object, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: raise ValueError("Need to provide %d index in tensor slice" % ndim) - indices = convert_to_node(indices) + indices = convert_to_object(indices) args = [] for x in indices: if isinstance(x, _expr.PrimExpr): @@ -127,7 +128,7 @@ def name(self): -class Operation(NodeBase): +class Operation(Object): """Represent an operation that generates a tensor""" def output(self, index): @@ -156,12 +157,12 @@ def input_tensors(self): return _api_internal._OpInputTensors(self) -@register_node +@register_object class PlaceholderOp(Operation): """Placeholder operation.""" -@register_node +@register_object class BaseComputeOp(Operation): """Compute operation.""" @property @@ -175,18 +176,18 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") -@register_node +@register_object class ComputeOp(BaseComputeOp): """Scalar operation.""" pass -@register_node +@register_object class TensorComputeOp(BaseComputeOp): """Tensor operation.""" -@register_node +@register_object class ScanOp(Operation): """Scan operation.""" @property @@ -195,12 +196,12 @@ def scan_axis(self): return self.__getattr__("scan_axis") -@register_node +@register_object class ExternOp(Operation): """External operation.""" -@register_node +@register_object class HybridOp(Operation): """Hybrid operation.""" @property @@ -209,8 +210,8 @@ def axis(self): return self.__getattr__("axis") -@register_node -class Layout(NodeBase): +@register_object +class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. @@ -269,8 +270,8 @@ def factor_of(self, axis): return _api_internal._LayoutFactorOf(self, axis) -@register_node -class BijectiveLayout(NodeBase): +@register_object +class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 378cfe51a7b7..4665ccfd6204 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -24,7 +24,7 @@ from . import tensor as _tensor from . import schedule as _schedule from .build_module import current_build_config -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object def _get_region(tslice): @@ -41,8 +41,8 @@ def _get_region(tslice): region.append(_make.range_by_min_extent(begin, 1)) return region -@register_node -class TensorIntrin(NodeBase): +@register_object +class TensorIntrin(Object): """Tensor intrinsic functions for certain computation. See Also diff --git a/tests/python/unittest/test_pass_inject_double_buffer.py b/tests/python/unittest/test_pass_inject_double_buffer.py index dc517e2ee28b..aa569cea8665 100644 --- a/tests/python/unittest/test_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_pass_inject_double_buffer.py @@ -28,7 +28,7 @@ def test_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asnode(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject(), "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -39,7 +39,7 @@ def test_double_buffer(): stmt = tvm.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op): diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index a3d059787ab8..08e261b68f6d 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -32,7 +32,7 @@ def get_vthread(name): ib.scope_attr(ty, "virtual_thread", nthread) B = ib.allocate("float32", m, name="B", scope="shared") B[i] = A[i * nthread + tx] - bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) + bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.call_extern("int32", "Run", bbuffer.access_ptr("r"), tvm.call_pure_intrin("int32", "tvm_context_id"))) @@ -60,9 +60,9 @@ def get_vthread(name): A = ib.allocate("float32", m, name="A", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared") C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode()) - abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode()) - bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) + cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) + abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) + bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit(tvm.call_extern("int32", "Run", diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 02edfe7d3261..da32f60f69fb 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asnode(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject(), "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -91,7 +91,7 @@ def test_flatten_double_buffer(): stmt = tvm.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op): From 5d55be419e72a7c91b57f44ce8b5aceb23c6da95 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 9 Jan 2020 23:51:23 +0000 Subject: [PATCH 3/3] rebase --- python/tvm/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index d147dd622fd3..71c0aecd1f6a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -242,7 +242,7 @@ def asobject(self): return _make._OpNE(self.a, self.b) -class PrimExpr(ExprOp, NodeBase): +class PrimExpr(ExprOp, Object): """Base class of all tvm Expressions""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__