From 1b49f3cca450d33a019de6578401717bebfeceb6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 17:54:26 -0800 Subject: [PATCH 1/7] Point submodules at my fork --- .gitmodules | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 984326434c3fb..b518c37e12117 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/dmlc/dmlc-core [submodule "HalideIR"] path = 3rdparty/HalideIR - url = https://github.com/dmlc/HalideIR + url = https://github.com/jroesch/HalideIR + beanch = vm [submodule "dlpack"] path = 3rdparty/dlpack url = https://github.com/dmlc/dlpack From 182bb37ab8582edef37473ecc057f1ac1ae6db61 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 17:59:03 -0800 Subject: [PATCH 2/7] Bump HalideIR to be my fork containing VMObject changes --- 3rdparty/HalideIR | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index 97efb11fff131..a0b9563f45719 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit 97efb11fff13131480fcaa5adc65a0aef4a4cb5d +Subproject commit a0b9563f45719553adf4d39fe3c14db1af0e1f40 From b220a05607ad18c95e21c144433ed31cddb7a5d2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 18:05:47 -0800 Subject: [PATCH 3/7] Fix .gitmodules --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index b518c37e12117..da9692b40e42b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,7 +4,7 @@ [submodule "HalideIR"] path = 3rdparty/HalideIR url = https://github.com/jroesch/HalideIR - beanch = vm + branch = vm [submodule "dlpack"] path = 3rdparty/dlpack url = https://github.com/dmlc/dlpack From 58d3a10e574f4da6f2b8796a98a325e8b8e1cffd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 18:08:27 -0800 Subject: [PATCH 4/7] Add bumped HalideIR ver --- 3rdparty/HalideIR | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index a0b9563f45719..23186ce2a65f5 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit a0b9563f45719553adf4d39fe3c14db1af0e1f40 +Subproject commit 23186ce2a65f5290e19a44f8d5b8048cba9e4769 From d095ce9c240bf88dcae7871145d9a5b9cb08964a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 18:09:10 -0800 Subject: [PATCH 5/7] Support for passing VMObjects around --- include/tvm/relay/vm/vm.h | 22 +++++++++++++------- include/tvm/runtime/c_runtime_api.h | 1 + include/tvm/runtime/packed_func.h | 11 ++++++++++ src/api/dsl_api.cc | 10 ++++++++++ src/lang/reflection.cc | 31 +++++++++++++++++++++++++++++ src/relay/ir/text_printer.cc | 4 ++++ src/relay/vm/vm.cc | 23 +++++++++++++++++++-- tests/python/relay/test_vm.py | 21 +++++++++---------- 8 files changed, 104 insertions(+), 19 deletions(-) diff --git a/include/tvm/relay/vm/vm.h b/include/tvm/relay/vm/vm.h index ccfb025a67b31..169a9d1aead3b 100644 --- a/include/tvm/relay/vm/vm.h +++ b/include/tvm/relay/vm/vm.h @@ -42,7 +42,15 @@ struct VMTensorCell : public VMObjectCell { : VMObjectCell(VMObjectTag::kTensor), data(data) {} }; -using VMObject = std::shared_ptr; +struct VMObject { + std::shared_ptr ptr; + VMObject(std::shared_ptr ptr) : ptr(ptr) {} + VMObject() : ptr() {} + VMObject(const VMObject& obj) : ptr(obj.ptr) {} + VMObjectCell* operator->() { + return this->ptr.operator->(); + } +}; struct VMDatatypeCell : public VMObjectCell { size_t tag; @@ -53,24 +61,24 @@ struct VMDatatypeCell : public VMObjectCell { }; -VMObject VMTensor(const tvm::runtime::NDArray& data) { +inline VMObject VMTensor(const tvm::runtime::NDArray& data) { auto ptr = std::make_shared(data); return std::dynamic_pointer_cast(ptr); } -VMObject VMDatatype(size_t tag, const std::vector& fields) { +inline VMObject VMDatatype(size_t tag, const std::vector& fields) { auto ptr = std::make_shared(tag, fields); return std::dynamic_pointer_cast(ptr); } -VMObject VMTuple(const std::vector& fields) { +inline VMObject VMTuple(const std::vector& fields) { return VMDatatype(0, fields); } inline NDArray ToNDArray(const VMObject& obj) { - CHECK(obj.get()); - CHECK(obj->tag == VMObjectTag::kTensor); - std::shared_ptr o = std::dynamic_pointer_cast(obj); + CHECK(obj.ptr.get()); + CHECK(obj.ptr->tag == VMObjectTag::kTensor); + std::shared_ptr o = std::dynamic_pointer_cast(obj.ptr); return o->data; } diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b493cf6dc8da1..903f731dcb5f0 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -85,6 +85,7 @@ typedef enum { kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, + kVMObject = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5179696ec42fd..5efda5ba83ddb 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ struct Type; struct Expr; } + // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY #define TVM_RUNTIME_HEADER_ONLY 0 @@ -35,6 +36,12 @@ namespace tvm { // forward declarations class Integer; +namespace relay { +namespace vm { + struct VMObject; +} +} + namespace runtime { // forward declarations class TVMArgs; @@ -569,6 +576,7 @@ class TVMArgValue : public TVMPODValue_ { inline operator tvm::Integer() const; // get internal node ptr, if it is node inline NodePtr& node_sptr(); + operator relay::vm::VMObject() const; }; /*! @@ -702,6 +710,9 @@ class TVMRetValue : public TVMPODValue_ { other.data_ = nullptr; return *this; } + + TVMRetValue& operator=(relay::vm::VMObject other); + TVMRetValue& operator=(PackedFunc f) { this->SwitchToClass(kFuncHandle, f); return *this; diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 1c2c294a5f306..55770d1165962 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,12 @@ struct APIAttrGetter : public AttrVisitor { found_ref_object = true; } } + void Visit(const char* key, relay::vm::VMObject* value) final { + if (skey == key) { + *ret = value[0]; + found_ref_object = true; + } + } }; struct APIAttrDir : public AttrVisitor { @@ -108,6 +115,9 @@ struct APIAttrDir : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } + void Visit(const char* key, relay::vm::VMObject* value) final { + names->push_back(key); + } }; class DSLAPIImpl : public DSLAPI { diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index 86a11a7e5b422..aba5fdacdfdc8 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,8 @@ inline Type String2Type(std::string s) { return TVMType2Type(runtime::String2TVMType(s)); } +using relay::vm::VMObject; +using relay::vm::VMObjectCell; // indexer to index all the ndoes class NodeIndexer : public AttrVisitor { @@ -42,6 +45,8 @@ class NodeIndexer : public AttrVisitor { std::vector node_list{nullptr}; std::unordered_map tensor_index; std::vector tensor_list; + std::unordered_map vm_obj_index; + std::vector vm_obj_list; void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} @@ -54,6 +59,7 @@ class NodeIndexer : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { MakeIndex(value->node_.get()); } + void Visit(const char* key, runtime::NDArray* value) final { DLTensor* ptr = const_cast((*value).operator->()); if (tensor_index.count(ptr)) return; @@ -61,6 +67,15 @@ class NodeIndexer : public AttrVisitor { tensor_index[ptr] = tensor_list.size(); tensor_list.push_back(ptr); } + + void Visit(const char* key, VMObject* value) final { + VMObjectCell* ptr = value->ptr.get(); + if (vm_obj_index.count(ptr)) return; + CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); + vm_obj_index[ptr] = vm_obj_list.size(); + vm_obj_list.push_back(ptr); + } + // make index of all the children of node void MakeIndex(Node* node) { if (node == nullptr) return; @@ -144,6 +159,7 @@ class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map* node_index_; const std::unordered_map* tensor_index_; + const std::unordered_map* vm_obj_index_; JSONNode* node_; void Visit(const char* key, double* value) final { @@ -178,6 +194,10 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs[key] = std::to_string( tensor_index_->at(const_cast((*value).operator->()))); } + void Visit(const char* key, VMObject* value) final { + node_->attrs[key] = std::to_string( + vm_obj_index_->at(value->ptr.get())); + } // Get the node void Get(Node* node) { if (node == nullptr) { @@ -231,6 +251,8 @@ class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; const std::vector* tensor_list_; + const std::vector* vm_obj_list_; + JSONNode* node_; std::string GetValue(const char* key) const { @@ -285,6 +307,12 @@ class JSONAttrSetter : public AttrVisitor { CHECK_LE(index, tensor_list_->size()); *value = tensor_list_->at(index); } + void Visit(const char* key, VMObject* value) final { + size_t index; + ParseValue(key, &index); + CHECK_LE(index, vm_obj_list_->size()); + *value = vm_obj_list_->at(index); + } // set node to be current JSONNode void Set(Node* node) { if (node == nullptr) return; @@ -462,6 +490,9 @@ class NodeAttrSetter : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { *value = GetAttr(key).operator runtime::NDArray(); } + void Visit(const char* key, VMObject* value) final { + *value = GetAttr(key).operator VMObject(); + } private: runtime::TVMArgValue GetAttr(const char* key) { diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 617e6c3cf77bb..8f6ca90dc4454 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -844,6 +844,10 @@ class TextPrinter::AttrPrinter: public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; } + void Visit(const char* key, relay::vm::VMObject* value) final { + LOG(FATAL) << "do not allow VMObject as argument"; + } + private: void PrintSep() { diff --git a/src/relay/vm/vm.cc b/src/relay/vm/vm.cc index 9d7f4b0f3b305..e8f240e50ca79 100644 --- a/src/relay/vm/vm.cc +++ b/src/relay/vm/vm.cc @@ -17,9 +17,23 @@ using namespace tvm::runtime; namespace tvm { + +// Packed Function extensions. +TVMRetValue& runtime::TVMRetValue::operator=(relay::vm::VMObject other) { + this->SwitchToClass(kVMObject, other); + return *this; +} + +runtime::TVMArgValue::operator relay::vm::VMObject() const { + if (type_code_ == kNull) return relay::vm::VMObject(nullptr); + TVM_CHECK_TYPE_CODE(type_code_, kVMObject); + return *ptr(); +} + namespace relay { namespace vm { + Instruction::Instruction() {} Instruction::Instruction(const Instruction& instr) { @@ -585,7 +599,7 @@ void VirtualMachine::Run() { case Opcode::GetField: { auto object = stack[bp + instr.object_offset]; CHECK(object->tag == VMObjectTag::kDatatype) << "Object is not data type object"; - auto tuple = std::dynamic_pointer_cast(object); + const std::shared_ptr& tuple = std::dynamic_pointer_cast(object.ptr); auto field = tuple->fields[instr.field_index]; stack.push_back(field); pc++; @@ -690,7 +704,7 @@ Value ConvertVMToValue(VMObject obj) { return TensorValueNode::make(ToNDArray(obj)); } case VMObjectTag::kDatatype: { - LOG(FATAL) << "unsupported return value: data type"; + LOG(FATAL) << "unsupported return value: data type"; return Value(); } default: @@ -708,6 +722,11 @@ VMObject EvaluateModule(const Module& module, const std::vector ctxs return vm.Invoke(module->entry_func, vm_args); } +TVM_REGISTER_API("relay._vm._Tensor") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = VMTensor(args[0]); +}); + TVM_REGISTER_API("relay._vm._evaluate_vm") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef to_compile = args[0]; diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cbb5de2f9e007..fcbc5dad48768 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -193,13 +193,14 @@ def test_rnn(): if __name__ == "__main__": test_id() - test_op() - test_cond() - test_simple_if() - test_simple_call() - test_count_loop() - test_sum_loop() - test_tuple_fst() - test_let_scalar() - test_let_tensor() - test_rnn() + # test_op() + # test_cond() + # test_simple_if() + # test_simple_call() + # test_count_loop() + # test_sum_loop() + # test_tuple_fst() + # test_tuple_second() + # test_let_scalar() + # test_let_tensor() + # test_rnn() From e8cb96d77049765c588f89b2bfdc08fac9c0ac5e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 20:47:25 -0800 Subject: [PATCH 6/7] Further improve support for VM object passing --- include/tvm/runtime/packed_func.h | 5 +++ python/tvm/_ffi/_ctypes/function.py | 14 ++++++++ python/tvm/_ffi/function.py | 3 +- python/tvm/_ffi/runtime_ctypes.py | 1 + python/tvm/relay/_vm.py | 3 ++ python/tvm/relay/backend/interpreter.py | 4 +++ python/tvm/relay/vm.py | 16 ++++++--- src/relay/vm/vm.cc | 44 +++++++++++++++++++++---- 8 files changed, 78 insertions(+), 12 deletions(-) create mode 100644 python/tvm/relay/_vm.py diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5efda5ba83ddb..3920845cbc54a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -808,6 +808,9 @@ class TVMRetValue : public TVMPODValue_ { kNodeHandle, *other.template ptr >()); break; } + case kVMObject: { + throw dmlc::Error("here"); + } default: { if (other.type_code() < kExtBegin) { SwitchToPOD(other.type_code()); @@ -855,6 +858,7 @@ class TVMRetValue : public TVMPODValue_ { static_cast(value_.v_handle)->DecRef(); break; } + // case kModuleHandle: delete ptr(); break; } if (type_code_ > kExtBegin) { #if TVM_RUNTIME_HEADER_ONLY @@ -884,6 +888,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; + case kVMObject: return "VMObject"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 3c2a7a5f8c9b2..82304117d7fd5 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -143,6 +143,9 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = arg.handle type_codes[i] = TypeCode.FUNC_HANDLE temp_args.append(arg) + elif isinstance(arg, VMObjectBase): + values[i].v_handle = arg.handle + type_codes[i] = TypeCode.VM_OBJECT else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -218,12 +221,18 @@ def _handle_return_func(x): handle = FunctionHandle(handle) return _CLASS_FUNCTION(handle, False) +class VMObjectBase(object): + __slots__ = ["handle"] + + def __init__(self, handle): + self.handle = handle # setup return handle for function type _node.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) +RETURN_SWITCH[TypeCode.VM_OBJECT] = lambda x: _CLASS_VM_OBJ(x.v_handle) C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( _handle_return_func, TypeCode.FUNC_HANDLE) C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( @@ -233,6 +242,7 @@ def _handle_return_func(x): _CLASS_MODULE = None _CLASS_FUNCTION = None +_CLASS_VM_OBJ = None def _set_class_module(module_class): """Initialize the module.""" @@ -242,3 +252,7 @@ def _set_class_module(module_class): def _set_class_function(func_class): global _CLASS_FUNCTION _CLASS_FUNCTION = func_class + +def _set_vm_obj_function(vm_obj_class): + global _CLASS_VM_OBJ + _CLASS_VM_OBJ = vm_obj_class diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index ca1812d4109ac..04a9153d270ae 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,8 @@ from ._cy2.core import convert_to_tvm_func except IMPORT_EXCEPT: # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_function, _set_class_module + from ._ctypes.function import _set_class_function, _set_class_module, _set_vm_obj_function + from ._ctypes.function import VMObjectBase as _VMObjectBase from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import convert_to_tvm_func diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index ef5316b5e2677..9d7605214c6f2 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -26,6 +26,7 @@ class TypeCode(object): STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 + VM_OBJECT = 14 EXT_BEGIN = 15 class TVMByteArray(ctypes.Structure): diff --git a/python/tvm/relay/_vm.py b/python/tvm/relay/_vm.py new file mode 100644 index 0000000000000..ff54b89300669 --- /dev/null +++ b/python/tvm/relay/_vm.py @@ -0,0 +1,3 @@ +from tvm._ffi.function import _init_api + +_init_api("relay._vm", __name__) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 0173844f1ef5a..6ac10c35714b6 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -10,6 +10,7 @@ from ..base import NodeBase, register_relay_node from ..expr import Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder +from .. import _vm class Value(NodeBase): """Base class of all values. @@ -20,6 +21,9 @@ def from_scalar(value, dtype=None): """Convert a Python scalar to a Relay scalar.""" return TensorValue(const(value, dtype).data) + def to_vm(self): + return _vm._ValueToVM(self) + @register_relay_node class TupleValue(Value): diff --git a/python/tvm/relay/vm.py b/python/tvm/relay/vm.py index 975b13e481905..6f98c3e5a59bd 100644 --- a/python/tvm/relay/vm.py +++ b/python/tvm/relay/vm.py @@ -1,15 +1,21 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """The interface of expr function exposed from C++.""" -from tvm._ffi.function import _init_api +import tvm +from tvm._ffi.function import _init_api, _VMObjectBase, _set_vm_obj_function from ..relay import ir_pass from ..relay.backend.interpreter import TensorValue, TupleValue, Executor from ..relay.module import Module from ..relay.expr import GlobalVar, Function, var, Call, Expr from ..relay.ty import FuncType +from . import _vm import numpy as np -_init_api("relay._vm", __name__) +class VMObject(_VMObjectBase): + def to_value(self): + return _vm._VMToValue(self) + +_set_vm_obj_function(VMObject) def optimize(expr, mod=None): # TODO: We need to move this optimization code into the optimizer/pass manager @@ -35,12 +41,12 @@ def eta_expand(expr, mod): def _convert(arg, cargs): if isinstance(arg, np.ndarray): - cargs.append(TensorValue(arg)) + cargs.append(_vm._Tensor(tvm.nd.array(arg))) elif isinstance(arg, tuple): field_args = [] for field in arg: _convert(field, field_args) - cargs.append(TupleValue(*field_args)) + cargs.append(_vm.Tuple(*field_args)) else: raise "unsupported type" @@ -67,7 +73,7 @@ def eval_vm(expr_or_mod, ctx, *args): cargs = convert(list(args)) # import pdb; pdb.set_trace() - return _evaluate_vm(mod, ctx.device_type, ctx.device_id, cargs) + return _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs).to_value() class VMExecutor(Executor): """ diff --git a/src/relay/vm/vm.cc b/src/relay/vm/vm.cc index e8f240e50ca79..e686492ceafe8 100644 --- a/src/relay/vm/vm.cc +++ b/src/relay/vm/vm.cc @@ -692,13 +692,14 @@ void ConvertArgsToVM(tvm::Array args, std::vector& out) { /*! \brief Convert from an array of relay.Value into VM compatible objects. */ -std::vector ConvertArgsToVM(tvm::Array args) { +VMObject ValueToVM(Value value) { std::vector out; - ConvertArgsToVM(args, out); - return out; + ConvertArgsToVM({value}, out); + CHECK_LT(out.size(), 2); + return out[0]; } -Value ConvertVMToValue(VMObject obj) { +Value VMToValue(VMObject obj) { switch (obj->tag) { case VMObjectTag::kTensor: { return TensorValueNode::make(ToNDArray(obj)); @@ -722,11 +723,36 @@ VMObject EvaluateModule(const Module& module, const std::vector ctxs return vm.Invoke(module->entry_func, vm_args); } +TVM_REGISTER_API("relay._vm._ValueToVM") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ValueToVM(args[0]); +}); + +TVM_REGISTER_API("relay._vm._VMToValue") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = VMToValue(args[0]); +}); + TVM_REGISTER_API("relay._vm._Tensor") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = VMTensor(args[0]); }); +TVM_REGISTER_API("relay._vm._Tuple") +.set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector fields; + for (size_t i = 0; i < args.size(); i++) { + fields.push_back(args[i]); + } + *ret = VMTuple(fields); +}); + +TVM_REGISTER_API("relay._vm._Datatype") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = VMTensor(args[0]); +}); + + TVM_REGISTER_API("relay._vm._evaluate_vm") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef to_compile = args[0]; @@ -745,9 +771,15 @@ TVM_REGISTER_API("relay._vm._evaluate_vm") LOG(FATAL) << "expected function or module"; } - std::vector vm_args = ConvertArgsToVM(args[3]); + std::cout << "About to get args " << std::endl; + std::vector vm_args; + for (auto i = 3; i < args.size(); i++) { + std::cout << "Arg: " << i << std::endl; + VMObject obj = args[i]; + vm_args.push_back(obj); + } auto result = EvaluateModule(module, {ctx}, vm_args); - *ret = ConvertVMToValue(result); + *ret = result; }); From 7ae24fb8067f70d820ca28f3a74f275523fa1513 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 4 Feb 2019 20:49:35 -0800 Subject: [PATCH 7/7] Restore tests --- python/tvm/relay/vm.py | 2 +- tests/python/relay/test_vm.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/vm.py b/python/tvm/relay/vm.py index 6f98c3e5a59bd..247adf9778940 100644 --- a/python/tvm/relay/vm.py +++ b/python/tvm/relay/vm.py @@ -46,7 +46,7 @@ def _convert(arg, cargs): field_args = [] for field in arg: _convert(field, field_args) - cargs.append(_vm.Tuple(*field_args)) + cargs.append(_vm._Tuple(*field_args)) else: raise "unsupported type" diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index fcbc5dad48768..15a8f4d32a29d 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -193,14 +193,14 @@ def test_rnn(): if __name__ == "__main__": test_id() - # test_op() - # test_cond() - # test_simple_if() - # test_simple_call() - # test_count_loop() - # test_sum_loop() - # test_tuple_fst() - # test_tuple_second() - # test_let_scalar() - # test_let_tensor() - # test_rnn() + test_op() + test_cond() + test_simple_if() + test_simple_call() + test_count_loop() + test_sum_loop() + test_tuple_fst() + test_tuple_second() + test_let_scalar() + test_let_tensor() + test_rnn()