Skip to content

Commit

Permalink
replace TensorObj and TensorValue with NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jan 7, 2020
1 parent dc30880 commit 1e61991
Show file tree
Hide file tree
Showing 20 changed files with 215 additions and 439 deletions.
105 changes: 30 additions & 75 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,11 @@
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {

/*!
* \brief A Relay value.
*/
class Value;

/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
Expand All @@ -65,39 +61,21 @@ class Value;
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<Value(Expr)>
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);

/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};

class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}

using ContainerType = ValueNode;
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
class ClosureNode : public Object {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
Expand All @@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
};

class Closure : public Value {
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
class RecClosureNode : public Object {
public:
/*! \brief The closure. */
Closure clos;
Expand All @@ -143,89 +121,66 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
};

class RecClosure : public Value {
class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<Value> value);
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
};

class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }

/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);

static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TensorValue : public Value {
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;
struct RefValueNode : Object {
mutable ObjectRef value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);
TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
};

class RefValue : public Value {
class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
struct ConstructorValueNode : Object {
int32_t tag;

tvm::Array<Value> fields;
tvm::Array<ObjectRef> fields;

/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;
Expand All @@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
};

class ConstructorValue : public Value {
class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
};

} // namespace relay
Expand Down
19 changes: 0 additions & 19 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,6 @@ namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;

static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};

/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);

TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};

/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
Expand Down
60 changes: 9 additions & 51 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,13 @@
from . import _backend
from .. import _make, analysis, transform
from .. import module
from ... import register_func, nd
from ... import nd
from ..base import NodeBase, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
from . import _vm

class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
return TensorValue(const(value, dtype).data)

def to_vm(self):
return _vm._ValueToVM(self)


@register_relay_node
class TupleValue(Value):
class TupleValue(NodeBase):
"""A tuple value produced by the interpreter."""
def __init__(self, *fields):
self.__init_handle_by_constructor__(
Expand All @@ -68,60 +54,32 @@ def __iter__(self):


@register_relay_node
class Closure(Value):
class Closure(NodeBase):
"""A closure produced by the interpreter."""


@register_relay_node
class RecClosure(Value):
class RecClosure(NodeBase):
"""A recursive closure produced by the interpreter."""


@register_relay_node
class ConstructorValue(Value):
class ConstructorValue(NodeBase):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor)


@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the interpreter."""

def __init__(self, data):
"""Allocate a new TensorValue and copy the data from `array` into
the new array.
"""
if isinstance(data, np.ndarray):
data = nd.array(data)

self.__init_handle_by_constructor__(
_make.TensorValue, data)

def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()

def __eq__(self, other):
return self.data == other.data

def __repr__(self):
return repr(self.data)

def __str__(self):
return str(self.data)


@register_relay_node
class RefValue(Value):
class RefValue(NodeBase):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)


def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
if isinstance(arg, nd.NDArray):
return Constant(arg.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
Expand Down Expand Up @@ -231,7 +189,7 @@ def evaluate(self, expr=None, binds=None):
Returns
-------
val : Union[function, Value]
val : Union[function, NodeBase]
The evaluation result.
"""
if binds:
Expand Down
12 changes: 7 additions & 5 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,26 @@
from . import vmobj as _obj
from .interpreter import Executor

Tensor = _obj.Tensor
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
cargs.append(arg.data)
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
cargs.append(nd_arr)
elif isinstance(arg, tvm.nd.NDArray):
cargs.append(arg)
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
cargs.append(value)
else:
raise TypeError("Unsupported type: %s" % (type(arg)))
Expand Down
Loading

0 comments on commit 1e61991

Please sign in to comment.