Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime][refactor] Unify vm and interpreter objects #4693

Merged
merged 4 commits into from
Jan 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _):
"tvm::relay::Span",
"tvm::relay::TempExpr",
"tvm::relay::TensorType",
"tvm::relay::TensorValue",
"tvm::relay::Tuple",
"tvm::relay::TupleGetItem",
"tvm::relay::TupleType",
"tvm::relay::TupleValue",
"tvm::relay::Type",
"tvm::relay::TypeCall",
"tvm::relay::TypeConstraint",
Expand Down
98 changes: 32 additions & 66 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
* can be produced by a Relay program and are exposed via TVM's object
* protocol to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
Expand All @@ -38,6 +38,8 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>

namespace tvm {
namespace relay {
Expand All @@ -64,11 +66,8 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target);

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public Object {
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::vm::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
Expand All @@ -82,102 +81,69 @@ class ClosureNode : public Object {
*/
Function func;

ClosureNode() {}
InterpreterClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
static constexpr const char* _type_key = "interpreter.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
};

class Closure : public ObjectRef {
class InterpreterClosure : public runtime::vm::Closure {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
InterpreterClosureObj);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public Object {
class RecClosureObj : public Object {
public:
/*! \brief The closure. */
Closure clos;
InterpreterClosure clos;
/*! \brief variable the closure bind to. */
Var bind;

RecClosureNode() {}
RecClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
static constexpr const char* _type_key = "interpreter.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};

class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : Object {
struct RefValueObj : Object {
mutable ObjectRef value;

RefValueNode() {}
RefValueObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};

class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
TVM_DLL RefValue(ObjectRef val);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : Object {
struct ConstructorValueObj : Object {
int32_t tag;

tvm::Array<ObjectRef> fields;
Expand All @@ -191,17 +157,17 @@ struct ConstructorValueNode : Object {
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};

class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
TVM_DLL ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

} // namespace relay
Expand Down
7 changes: 3 additions & 4 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ namespace runtime {
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
38 changes: 30 additions & 8 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,58 @@
#define TVM_RUNTIME_VM_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object representing a closure. */
/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
*/
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

/*!
* \brief An object representing a vm closure.
*/
class VMClosureObj : public ClosureObj {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to the VM runtime.
*/
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;

static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
zhiics marked this conversation as resolved.
Show resolved Hide resolved
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
class VMClosure : public Closure {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);

TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
};

/*! \brief Magic number for NDArray list file */
Expand Down
57 changes: 56 additions & 1 deletion python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
from tvm import ndarray as _nd
from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api

@register_object
class Array(Object):
Expand Down Expand Up @@ -114,3 +116,56 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2


@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
zhiics marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

@property
def tag(self):
return _GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)


def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)


_init_api("tvm.container")
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import vmobj

# Root operators
from .op import Op
Expand Down
20 changes: 0 additions & 20 deletions python/tvm/relay/backend/_vmobj.py

This file was deleted.

Loading