Skip to content

Commit

Permalink
[runtime][refactor] Unify vm and interpreter objects (apache#4693)
Browse files Browse the repository at this point in the history
* unify vm and interpreter objects

* move closure back vm

* adt/closure back to vm.adt/vm.closure

* closure base
  • Loading branch information
zhiics authored and alexwong committed Feb 28, 2020
1 parent c6848c6 commit 174ddbe
Show file tree
Hide file tree
Showing 26 changed files with 370 additions and 463 deletions.
2 changes: 0 additions & 2 deletions apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _):
"tvm::relay::Span",
"tvm::relay::TempExpr",
"tvm::relay::TensorType",
"tvm::relay::TensorValue",
"tvm::relay::Tuple",
"tvm::relay::TupleGetItem",
"tvm::relay::TupleType",
"tvm::relay::TupleValue",
"tvm::relay::Type",
"tvm::relay::TypeCall",
"tvm::relay::TypeConstraint",
Expand Down
98 changes: 32 additions & 66 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
* can be produced by a Relay program and are exposed via TVM's object
* protocol to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
Expand All @@ -38,6 +38,8 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>

namespace tvm {
namespace relay {
Expand All @@ -64,11 +66,8 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target);

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public Object {
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::vm::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
Expand All @@ -82,102 +81,69 @@ class ClosureNode : public Object {
*/
Function func;

ClosureNode() {}
InterpreterClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
static constexpr const char* _type_key = "interpreter.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
};

class Closure : public ObjectRef {
class InterpreterClosure : public runtime::vm::Closure {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
InterpreterClosureObj);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public Object {
class RecClosureObj : public Object {
public:
/*! \brief The closure. */
Closure clos;
InterpreterClosure clos;
/*! \brief variable the closure bind to. */
Var bind;

RecClosureNode() {}
RecClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
static constexpr const char* _type_key = "interpreter.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};

class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : Object {
struct RefValueObj : Object {
mutable ObjectRef value;

RefValueNode() {}
RefValueObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};

class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
TVM_DLL RefValue(ObjectRef val);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : Object {
struct ConstructorValueObj : Object {
int32_t tag;

tvm::Array<ObjectRef> fields;
Expand All @@ -191,17 +157,17 @@ struct ConstructorValueNode : Object {
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};

class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
TVM_DLL ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

} // namespace relay
Expand Down
7 changes: 3 additions & 4 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ namespace runtime {
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
38 changes: 30 additions & 8 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,58 @@
#define TVM_RUNTIME_VM_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object representing a closure. */
/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
*/
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

/*!
* \brief An object representing a vm closure.
*/
class VMClosureObj : public ClosureObj {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to the VM runtime.
*/
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;

static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
class VMClosure : public Closure {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);

TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
};

/*! \brief Magic number for NDArray list file */
Expand Down
57 changes: 56 additions & 1 deletion python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
from tvm import ndarray as _nd
from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api

@register_object
class Array(Object):
Expand Down Expand Up @@ -114,3 +116,56 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2


@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

@property
def tag(self):
return _GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)


def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)


_init_api("tvm.container")
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import vmobj

# Root operators
from .op import Op
Expand Down
20 changes: 0 additions & 20 deletions python/tvm/relay/backend/_vmobj.py

This file was deleted.

Loading

0 comments on commit 174ddbe

Please sign in to comment.