Skip to content

Commit

Permalink
Merge pull request #5 from jroesch/vm_obj_pf
Browse files Browse the repository at this point in the history
Add the ability to directly pass VMObjects back and forth from PackedFuncs.
  • Loading branch information
jroesch authored Feb 5, 2019
2 parents 5c958ec + 7ae24fb commit 8661269
Show file tree
Hide file tree
Showing 16 changed files with 175 additions and 23 deletions.
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
url = https://github.com/dmlc/dmlc-core
[submodule "HalideIR"]
path = 3rdparty/HalideIR
url = https://github.com/dmlc/HalideIR
url = https://github.com/jroesch/HalideIR
branch = vm
[submodule "dlpack"]
path = 3rdparty/dlpack
url = https://github.com/dmlc/dlpack
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
22 changes: 15 additions & 7 deletions include/tvm/relay/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,15 @@ struct VMTensorCell : public VMObjectCell {
: VMObjectCell(VMObjectTag::kTensor), data(data) {}
};

using VMObject = std::shared_ptr<VMObjectCell>;
struct VMObject {
std::shared_ptr<VMObjectCell> ptr;
VMObject(std::shared_ptr<VMObjectCell> ptr) : ptr(ptr) {}
VMObject() : ptr() {}
VMObject(const VMObject& obj) : ptr(obj.ptr) {}
VMObjectCell* operator->() {
return this->ptr.operator->();
}
};

struct VMDatatypeCell : public VMObjectCell {
size_t tag;
Expand All @@ -53,24 +61,24 @@ struct VMDatatypeCell : public VMObjectCell {
};


VMObject VMTensor(const tvm::runtime::NDArray& data) {
inline VMObject VMTensor(const tvm::runtime::NDArray& data) {
auto ptr = std::make_shared<VMTensorCell>(data);
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
}

VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields) {
inline VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields) {
auto ptr = std::make_shared<VMDatatypeCell>(tag, fields);
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
}

VMObject VMTuple(const std::vector<VMObject>& fields) {
inline VMObject VMTuple(const std::vector<VMObject>& fields) {
return VMDatatype(0, fields);
}

inline NDArray ToNDArray(const VMObject& obj) {
CHECK(obj.get());
CHECK(obj->tag == VMObjectTag::kTensor);
std::shared_ptr<VMTensorCell> o = std::dynamic_pointer_cast<VMTensorCell>(obj);
CHECK(obj.ptr.get());
CHECK(obj.ptr->tag == VMObjectTag::kTensor);
std::shared_ptr<VMTensorCell> o = std::dynamic_pointer_cast<VMTensorCell>(obj.ptr);
return o->data;
}

Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kVMObject = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct Type;
struct Expr;
}


// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
Expand All @@ -35,6 +36,12 @@ namespace tvm {
// forward declarations
class Integer;

namespace relay {
namespace vm {
struct VMObject;
}
}

namespace runtime {
// forward declarations
class TVMArgs;
Expand Down Expand Up @@ -569,6 +576,7 @@ class TVMArgValue : public TVMPODValue_ {
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
operator relay::vm::VMObject() const;
};

/*!
Expand Down Expand Up @@ -702,6 +710,9 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
return *this;
}

TVMRetValue& operator=(relay::vm::VMObject other);

TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
Expand Down Expand Up @@ -797,6 +808,9 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kVMObject: {
throw dmlc::Error("here");
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
Expand Down Expand Up @@ -844,6 +858,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
// case kModuleHandle: delete ptr<relay::vm::VMObject>(); break;
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
Expand Down Expand Up @@ -873,6 +888,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kVMObject: return "VMObject";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
elif isinstance(arg, VMObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.VM_OBJECT
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
Expand Down Expand Up @@ -218,12 +221,18 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)

class VMObjectBase(object):
__slots__ = ["handle"]

def __init__(self, handle):
self.handle = handle

# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
RETURN_SWITCH[TypeCode.VM_OBJECT] = lambda x: _CLASS_VM_OBJ(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
Expand All @@ -233,6 +242,7 @@ def _handle_return_func(x):

_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_VM_OBJ = None

def _set_class_module(module_class):
"""Initialize the module."""
Expand All @@ -242,3 +252,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class

def _set_vm_obj_function(vm_obj_class):
global _CLASS_VM_OBJ
_CLASS_VM_OBJ = vm_obj_class
3 changes: 2 additions & 1 deletion python/tvm/_ffi/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import _set_class_function, _set_class_module, _set_vm_obj_function
from ._ctypes.function import VMObjectBase as _VMObjectBase
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func

Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
VM_OBJECT = 14
EXT_BEGIN = 15

class TVMByteArray(ctypes.Structure):
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/_vm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from tvm._ffi.function import _init_api

_init_api("relay._vm", __name__)
4 changes: 4 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
from .. import _vm

class Value(NodeBase):
"""Base class of all values.
Expand All @@ -20,6 +21,9 @@ def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
return TensorValue(const(value, dtype).data)

def to_vm(self):
return _vm._ValueToVM(self)


@register_relay_node
class TupleValue(Value):
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/relay/vm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
import tvm
from tvm._ffi.function import _init_api, _VMObjectBase, _set_vm_obj_function
from ..relay import ir_pass
from ..relay.backend.interpreter import TensorValue, TupleValue, Executor
from ..relay.module import Module
from ..relay.expr import GlobalVar, Function, var, Call, Expr
from ..relay.ty import FuncType
from . import _vm

import numpy as np

_init_api("relay._vm", __name__)
class VMObject(_VMObjectBase):
def to_value(self):
return _vm._VMToValue(self)

_set_vm_obj_function(VMObject)

def optimize(expr, mod=None):
# TODO: We need to move this optimization code into the optimizer/pass manager
Expand All @@ -35,12 +41,12 @@ def eta_expand(expr, mod):

def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
cargs.append(TensorValue(arg))
cargs.append(_vm._Tensor(tvm.nd.array(arg)))
elif isinstance(arg, tuple):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(TupleValue(*field_args))
cargs.append(_vm._Tuple(*field_args))
else:
raise "unsupported type"

Expand All @@ -67,7 +73,7 @@ def eval_vm(expr_or_mod, ctx, *args):

cargs = convert(list(args))
# import pdb; pdb.set_trace()
return _evaluate_vm(mod, ctx.device_type, ctx.device_id, cargs)
return _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs).to_value()

class VMExecutor(Executor):
"""
Expand Down
10 changes: 10 additions & 0 deletions src/api/dsl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <tvm/relay/vm/vm.h>
#include <vector>
#include <string>
#include <exception>
Expand Down Expand Up @@ -73,6 +74,12 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, relay::vm::VMObject* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};

struct APIAttrDir : public AttrVisitor {
Expand Down Expand Up @@ -108,6 +115,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, relay::vm::VMObject* value) final {
names->push_back(key);
}
};

class DSLAPIImpl : public DSLAPI {
Expand Down
31 changes: 31 additions & 0 deletions src/lang/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/relay/vm/vm.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <string>
Expand All @@ -34,6 +35,8 @@ inline Type String2Type(std::string s) {
return TVMType2Type(runtime::String2TVMType(s));
}

using relay::vm::VMObject;
using relay::vm::VMObjectCell;

// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
Expand All @@ -42,6 +45,8 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
std::unordered_map<VMObjectCell*, size_t> vm_obj_index;
std::vector<VMObjectCell*> vm_obj_list;

void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
Expand All @@ -54,13 +59,23 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}

void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
CHECK_EQ(tensor_index.size(), tensor_list.size());
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}

void Visit(const char* key, VMObject* value) final {
VMObjectCell* ptr = value->ptr.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
}

// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
Expand Down Expand Up @@ -144,6 +159,7 @@ class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<VMObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_;

void Visit(const char* key, double* value) final {
Expand Down Expand Up @@ -178,6 +194,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, VMObject* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr.get()));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
Expand Down Expand Up @@ -231,6 +251,8 @@ class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<VMObject>* vm_obj_list_;

JSONNode* node_;

std::string GetValue(const char* key) const {
Expand Down Expand Up @@ -285,6 +307,12 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
void Visit(const char* key, VMObject* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
}
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
Expand Down Expand Up @@ -462,6 +490,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, VMObject* value) final {
*value = GetAttr(key).operator VMObject();
}

private:
runtime::TVMArgValue GetAttr(const char* key) {
Expand Down
Loading

0 comments on commit 8661269

Please sign in to comment.