Skip to content

Commit

Permalink
closure base
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jan 17, 2020
1 parent 040a532 commit b4347a6
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 92 deletions.
58 changes: 40 additions & 18 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,44 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target);

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::vm::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;

InterpreterClosureObj() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}

static constexpr const char* _type_key = "interpreter.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
};

class InterpreterClosure : public runtime::vm::Closure {
public:
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
InterpreterClosureObj);
};

/*! \brief The container type of RecClosure. */
class RecClosureObj : public Object {
public:
/*! \brief The closure. */
runtime::vm::Closure clos;
InterpreterClosure clos;
/*! \brief variable the closure bind to. */
Var bind;

Expand All @@ -84,20 +114,16 @@ class RecClosureObj : public Object {
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(runtime::vm::Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
static constexpr const char* _type_key = "interpreter.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};

class RecClosure : public ObjectRef {
public:
TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueObj : Object {
mutable ObjectRef value;

Expand All @@ -107,20 +133,16 @@ struct RefValueObj : Object {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};

class RefValue : public ObjectRef {
public:
TVM_DLL RefValue(ObjectRef val);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueObj : Object {
int32_t tag;

Expand All @@ -135,16 +157,16 @@ struct ConstructorValueObj : Object {
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};

class ConstructorValue : public ObjectRef {
public:
TVM_DLL ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace runtime {
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMClosure = 1,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
Expand Down
34 changes: 21 additions & 13 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,40 @@ namespace vm {
* Relay VM and interpreter.
*/
class ClosureObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};

/*!
* \brief An object representing a vm closure.
*/
class VMClosureObj : public ClosureObj {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to a certain runtime, i.e. VM or
* interpreter.
* function object that is compatible to the VM runtime.
*/
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;

static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
};

/*! \brief reference to closure. */
class Closure : public ObjectRef {
class VMClosure : public Closure {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}

TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
};

/*! \brief Magic number for NDArray list file */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

import tvm
from tvm import autotvm, container
from tvm.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from tvm._ffi.object import Object
from . import _vm
from .interpreter import Executor

Expand Down
Loading

0 comments on commit b4347a6

Please sign in to comment.