Skip to content

Commit

Permalink
unify vm and interpreter objects
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jan 13, 2020
1 parent bd17baa commit 88c6be4
Show file tree
Hide file tree
Showing 27 changed files with 389 additions and 479 deletions.
2 changes: 0 additions & 2 deletions apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _):
"tvm::relay::Span",
"tvm::relay::TempExpr",
"tvm::relay::TensorType",
"tvm::relay::TensorValue",
"tvm::relay::Tuple",
"tvm::relay::TupleGetItem",
"tvm::relay::TupleType",
"tvm::relay::TupleValue",
"tvm::relay::Type",
"tvm::relay::TypeCall",
"tvm::relay::TypeConstraint",
Expand Down
90 changes: 17 additions & 73 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
* can be produced by a Relay program and are exposed via tvm's object
* protocol to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
Expand All @@ -38,6 +38,8 @@
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/common_object.h>

namespace tvm {
namespace relay {
Expand All @@ -64,100 +66,42 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public Object {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;

ClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
};

class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public Object {
class RecClosureObj : public Object {
public:
/*! \brief The closure. */
Closure clos;
runtime::Closure clos;
/*! \brief variable the closure bind to. */
Var bind;

RecClosureNode() {}
RecClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(Closure clos, Var bind);
TVM_DLL static RecClosure make(runtime::Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};

class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : Object {
struct RefValueObj : Object {
mutable ObjectRef value;

RefValueNode() {}
RefValueObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
Expand All @@ -166,18 +110,18 @@ struct RefValueNode : Object {
TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};

class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : Object {
struct ConstructorValueObj : Object {
int32_t tag;

tvm::Array<ObjectRef> fields;
Expand All @@ -196,12 +140,12 @@ struct ConstructorValueNode : Object {
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};

class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

} // namespace relay
Expand Down
71 changes: 71 additions & 0 deletions include/tvm/runtime/common_object.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/common_object.h
* \brief The objects that are commonly used by different runtime, i.e. Relay VM
* and interpreter.
*/
#ifndef TVM_RUNTIME_COMMON_OBJECT_H_
#define TVM_RUNTIME_COMMON_OBJECT_H_

#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {

/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
*/
class ClosureObj : public Object {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to a certain runtime, i.e. VM or
* interpreter.
*/
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;

static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}

TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_COMMON_OBJECT_H_
4 changes: 2 additions & 2 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size;
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kADT;
static constexpr const char* _type_key = "ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);

private:
Expand Down
7 changes: 3 additions & 4 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ namespace runtime {
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kClosure = 1,
kADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
22 changes: 1 addition & 21 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RUNTIME_VM_H_

#include <tvm/runtime/object.h>
#include <tvm/runtime/common_object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
Expand All @@ -36,27 +37,6 @@ namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;

static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);

TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

Expand Down
57 changes: 56 additions & 1 deletion python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
from tvm import ndarray as _nd
from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api

@register_object
class Array(Object):
Expand Down Expand Up @@ -114,3 +116,56 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2


@register_object
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

@property
def tag(self):
return _GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)

def __len__(self):
return _GetADTSize(self)


def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)


_init_api("tvm.container")
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import vmobj

# Root operators
from .op import Op
Expand Down
Loading

0 comments on commit 88c6be4

Please sign in to comment.