From b4347a63dcb6b64f5e9cdc92495071bf299ffc89 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 17 Jan 2020 19:10:27 +0000 Subject: [PATCH] closure base --- include/tvm/relay/interpreter.h | 58 +++++++++++----- include/tvm/runtime/object.h | 2 +- include/tvm/runtime/vm.h | 34 ++++++---- python/tvm/relay/backend/vm.py | 2 +- src/relay/backend/interpreter.cc | 110 +++++++++++++++---------------- src/runtime/vm/vm.cc | 10 ++- 6 files changed, 124 insertions(+), 92 deletions(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 9309171003c2..e090dc85f238 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -66,14 +66,44 @@ namespace relay { runtime::TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target); -/*! \brief A Relay Recursive Closure. A closure that has a name. */ -class RecClosure; +/*! \brief The container type of Closures used by the interpreter. */ +class InterpreterClosureObj : public runtime::vm::ClosureObj { + public: + /*! \brief The set of free variables in the closure. + * + * These are the captured variables which are required for + * evaluation when we call the closure. + */ + tvm::Map env; + /*! \brief The function which implements the closure. + * + * \note May reference the variables contained in the env. + */ + Function func; + + InterpreterClosureObj() {} + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("env", &env); + v->Visit("func", &func); + } + + static constexpr const char* _type_key = "interpreter.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj); +}; + +class InterpreterClosure : public runtime::vm::Closure { + public: + TVM_DLL InterpreterClosure(tvm::Map env, Function func); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, + InterpreterClosureObj); +}; /*! \brief The container type of RecClosure. */ class RecClosureObj : public Object { public: /*! \brief The closure. */ - runtime::vm::Closure clos; + InterpreterClosure clos; /*! \brief variable the closure bind to. */ Var bind; @@ -84,20 +114,16 @@ class RecClosureObj : public Object { v->Visit("bind", &bind); } - TVM_DLL static RecClosure make(runtime::vm::Closure clos, Var bind); - - static constexpr const char* _type_key = "relay.RecClosure"; + static constexpr const char* _type_key = "interpreter.RecClosure"; TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object); }; class RecClosure : public ObjectRef { public: + TVM_DLL RecClosure(InterpreterClosure clos, Var bind); TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj); }; -/*! \brief A reference value. */ -class RefValue; - struct RefValueObj : Object { mutable ObjectRef value; @@ -107,20 +133,16 @@ struct RefValueObj : Object { v->Visit("value", &value); } - TVM_DLL static RefValue make(ObjectRef val); - static constexpr const char* _type_key = "relay.RefValue"; TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); }; class RefValue : public ObjectRef { public: + TVM_DLL RefValue(ObjectRef val); TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj); }; -/*! \brief An ADT constructor value. */ -class ConstructorValue; - struct ConstructorValueObj : Object { int32_t tag; @@ -135,16 +157,16 @@ struct ConstructorValueObj : Object { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int32_t tag, - tvm::Array fields, - Constructor construtor = {}); - static constexpr const char* _type_key = "relay.ConstructorValue"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object); }; class ConstructorValue : public ObjectRef { public: + TVM_DLL ConstructorValue(int32_t tag, + tvm::Array fields, + Constructor construtor = {}); + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 30b91865a041..5314cab43f68 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -51,7 +51,7 @@ namespace runtime { enum TypeIndex { /*! \brief Root object type. */ kRoot = 0, - kVMClosure = 1, + kClosure = 1, kVMADT = 2, kRuntimeModule = 3, kStaticIndexEnd, diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 17ffd474c427..43c222d0994a 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -43,32 +43,40 @@ namespace vm { * Relay VM and interpreter. */ class ClosureObj : public Object { + public: + static constexpr const uint32_t _type_index = TypeIndex::kClosure; + static constexpr const char* _type_key = "Closure"; + TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. */ +class Closure : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + +/*! + * \brief An object representing a vm closure. + */ +class VMClosureObj : public ClosureObj { public: /*! * \brief The index into the function list. The function could be any - * function object that is compatible to a certain runtime, i.e. VM or - * interpreter. + * function object that is compatible to the VM runtime. */ size_t func_index; /*! \brief The free variables of the closure. */ std::vector free_vars; - static constexpr const uint32_t _type_index = TypeIndex::kVMClosure; static constexpr const char* _type_key = "vm.Closure"; - TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj); }; /*! \brief reference to closure. */ -class Closure : public ObjectRef { +class VMClosure : public Closure { public: - Closure(size_t func_index, std::vector free_vars) { - auto ptr = make_object(); - ptr->func_index = func_index; - ptr->free_vars = std::move(free_vars); - data_ = std::move(ptr); - } - - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); + VMClosure(size_t func_index, std::vector free_vars); + TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj); }; /*! \brief Magic number for NDArray list file */ diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 6279b4146ce7..31009008b23c 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,10 +24,10 @@ import tvm from tvm import autotvm, container +from tvm.object import Object from tvm.relay import expr as _expr from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi import base as _base -from tvm._ffi.object import Object from . import _vm from .interpreter import Executor diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index fa40e6a8e319..88ced23fa41a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -37,7 +37,20 @@ namespace tvm { namespace relay { using namespace runtime; -using namespace runtime::vm; + +InterpreterClosure::InterpreterClosure(tvm::Map env, + Function func) { + ObjectPtr n = make_object(); + n->env = std::move(env); + n->func = std::move(func); + data_ = std::move(n); +} + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; +}); inline const PackedFunc& GetPackedFunc(const std::string& name) { const PackedFunc* pf = tvm::runtime::Registry::Get(name); @@ -47,11 +60,11 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { // TODO(@jroesch): this doesn't support mutual letrec /* Object Implementation */ -RecClosure RecClosureObj::make(Closure clos, Var bind) { +RecClosure::RecClosure(InterpreterClosure clos, Var bind) { ObjectPtr n = make_object(); n->clos = std::move(clos); n->bind = std::move(bind); - return RecClosure(n); + data_ = std::move(n); } TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) @@ -60,14 +73,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "RecClosureObj(" << node->clos << ")"; }); -RefValue RefValueObj::make(ObjectRef value) { +RefValue::RefValue(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; - return RefValue(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed(RefValueObj::make); +.set_body_typed([](ObjectRef value){ + return RefValue(value); +}); TVM_REGISTER_NODE_TYPE(RefValueObj); @@ -77,18 +92,21 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "RefValueObj(" << node->value << ")"; }); -ConstructorValue ConstructorValueObj::make(int32_t tag, - tvm::Array fields, - Constructor constructor) { +ConstructorValue::ConstructorValue(int32_t tag, + tvm::Array fields, + Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; n->constructor = constructor; - return ConstructorValue(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed(ConstructorValueObj::make); +.set_body_typed([](int32_t tag, tvm::Array fields, + Constructor constructor) { + return ConstructorValue(tag, fields, constructor); +}); TVM_REGISTER_NODE_TYPE(ConstructorValueObj); @@ -153,7 +171,7 @@ struct Stack { class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ -class InterpreterStateNode : public Object { +class InterpreterStateObj : public Object { public: using Frame = tvm::Map; using Stack = tvm::Array; @@ -172,16 +190,16 @@ class InterpreterStateNode : public Object { static InterpreterState make(Expr current_expr, Stack stack); static constexpr const char* _type_key = "relay.InterpreterState"; - TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); }; class InterpreterState : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); }; -InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { - ObjectPtr n = make_object(); +InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { + ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); return InterpreterState(n); @@ -262,13 +280,8 @@ class Interpreter : } ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { - if (func_index_map_.count(func) == 0) { - func_index_map_[func] = func_index_++; - eval_funcs_.push_back(func); - } - std::vector free_var_values; + tvm::Map captured_mod; Array free_vars = FreeVars(func); - std::vector captured_vars; for (const auto& var : free_vars) { // Evaluate the free var (which could be a function call) if it hasn't @@ -277,16 +290,13 @@ class Interpreter : continue; } - ObjectRef value = Eval(var); - free_var_values.push_back(value); - captured_vars.push_back(var); + captured_mod.Set(var, Eval(var)); } // We must use mutation here to build a self referential closure. - Closure closure(func_index_map_[func], free_var_values); - closure_captured_vars_[closure] = captured_vars; + InterpreterClosure closure(captured_mod, func); if (letrec_name.defined()) { - return RecClosureObj::make(closure, letrec_name); + return RecClosure(closure, letrec_name); } return std::move(closure); } @@ -540,23 +550,18 @@ class Interpreter : } // Invoke the closure - ObjectRef Invoke(const Closure& closure, + ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { - CHECK_GT(eval_funcs_.size(), closure->func_index); - CHECK_GT(func_index_map_.count(eval_funcs_[closure->func_index]), 0U); - auto func = eval_funcs_[closure->func_index]; // Get a reference to the function inside the closure. - if (func->IsPrimitive()) { - return InvokePrimitiveOp(func, args); + if (closure->func->IsPrimitive()) { + return InvokePrimitiveOp(closure->func, args); } + auto func = closure->func; // Allocate a frame with the parameters and free variables. tvm::Map locals; CHECK_EQ(func->params.size(), args.size()); - CHECK_GT(closure_captured_vars_.count(closure), 0U); - const auto& captured_vars = closure_captured_vars_[closure]; - CHECK_EQ(captured_vars.size(), closure->free_vars.size()); for (size_t i = 0; i < func->params.size(); i++) { CHECK_EQ(locals.count(func->params[i]), 0); @@ -564,14 +569,13 @@ class Interpreter : } // Add the var to value mappings from the Closure's environment. - for (size_t i = 0; i < closure->free_vars.size(); i++) { - Var var = captured_vars[i]; - CHECK_EQ(locals.count(var), 0); - locals.Set(var, closure->free_vars[i]); + for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { + CHECK_EQ(locals.count((*it).first), 0); + locals.Set((*it).first, (*it).second); } if (bind.defined()) { - locals.Set(bind, RecClosureObj::make(closure, bind)); + locals.Set(bind, RecClosure(closure, bind)); } return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); @@ -593,12 +597,12 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueObj::make(con->tag, args, GetRef(con)); + return ConstructorValue(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. ObjectRef fn_val = Eval(call->op); - if (const ClosureObj* closure_node = fn_val.as()) { - auto closure = GetRef(closure_node); + if (const InterpreterClosureObj* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); return this->Invoke(closure, args); } else if (const RecClosureObj* closure_node = fn_val.as()) { return this->Invoke(closure_node->clos, args, closure_node->bind); @@ -665,7 +669,7 @@ class Interpreter : } ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValueObj::make(Eval(op->value)); + return RefValue(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { @@ -727,12 +731,12 @@ class Interpreter : } InterpreterState get_state(Expr e = Expr()) const { - InterpreterStateNode::Stack stack; + InterpreterStateObj::Stack stack; for (auto fr : this->stack_.frames) { - InterpreterStateNode::Frame frame = fr.locals; + InterpreterStateObj::Frame frame = fr.locals; stack.push_back(frame); } - auto state = InterpreterStateNode::make(e, stack); + auto state = InterpreterStateObj::make(e, stack); return state; } @@ -751,14 +755,6 @@ class Interpreter : // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; const Op& shape_of_op_; - // The free vars captured by the last closure. - std::unordered_map, ObjectHash, ObjectEqual> closure_captured_vars_; - // The index of the Relay function being evaluated. - int func_index_{0}; - // The Relay function to index map. - std::unordered_map func_index_map_; - // The saved functions. - std::vector eval_funcs_; }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c5ab1fdb4b62..84a3e26fb7f9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -45,6 +45,12 @@ namespace tvm { namespace runtime { namespace vm { +VMClosure::VMClosure(size_t func_index, std::vector free_vars) { + auto ptr = make_object(); + ptr->func_index = func_index; + ptr->free_vars = std::move(free_vars); + data_ = std::move(ptr); +} inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) { // We could put cache in here, from ctx to storage allocator. @@ -906,7 +912,7 @@ void VirtualMachine::RunLoop() { } case Opcode::InvokeClosure: { auto object = ReadRegister(instr.closure); - const auto* closure = object.as(); + const auto* closure = object.as(); std::vector args; for (auto free_var : closure->free_vars) { @@ -1008,7 +1014,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_freevar; i++) { free_vars.push_back(ReadRegister(instr.free_vars[i])); } - WriteRegister(instr.dst, Closure(instr.func_index, free_vars)); + WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars)); pc_++; goto main_loop; }