diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index f77a927eeabf..c71cbd2b0c2d 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -23,7 +23,6 @@ from tvm._ffi.function import get_global_func from tvm.contrib import graph_runtime from tvm.ndarray import array -from tvm.rpc import base as rpc_base from . import debug_result _DUMP_ROOT_PREFIX = "tvmdbg_" @@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None): except AttributeError: raise ValueError("Type %s is not supported" % type(graph_json_str)) try: - fcreate = get_global_func("tvm.graph_runtime_debug.create") + ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) + if num_rpc_ctx == len(ctx): + fcreate = ctx[0]._rpc_sess.get_function( + "tvm.graph_runtime_debug.create") + else: + fcreate = get_global_func("tvm.graph_runtime_debug.create") except ValueError: raise ValueError( "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " "config.cmake and rebuild TVM to enable debug mode" ) - - ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) - if num_rpc_ctx == len(ctx): - libmod = rpc_base._ModuleHandle(libmod) - try: - fcreate = ctx[0]._rpc_sess.get_function( - "tvm.graph_runtime_debug.remote_create" - ) - except ValueError: - raise ValueError( - "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " - "config.cmake and rebuild TVM to enable debug mode" - ) func_obj = fcreate(graph_json_str, libmod, *device_type_id) return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 3e182e26fd22..f4ee2f7db28d 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx): ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): - hmod = rpc_base._ModuleHandle(libmod) - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create") - return GraphModule(fcreate(graph_json_str, hmod, *device_type_id)) + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create") + else: + fcreate = get_global_func("tvm.graph_runtime.create") - fcreate = get_global_func("tvm.graph_runtime.create") return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) def get_device_ctx(libmod, ctx): diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index ab28cb662f2a..314ddabe2c7b 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -27,7 +27,6 @@ #include #include #include "../graph_runtime.h" -#include "../../object_internal.h" namespace tvm { namespace runtime { @@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") << args.num_args; *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); }); - -TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) << "The expected number of arguments for " - "graph_runtime.remote_create is " - "at least 4, but it has " - << args.num_args; - void* mhandle = args[1]; - ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle); - const auto& contexts = GetAllContext(args); - *rv = GraphRuntimeDebugCreate( - args[0], GetRef(mnode), contexts); - }); - } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index b286a5328d82..cc2478f27d19 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -36,7 +36,6 @@ #include #include "graph_runtime.h" -#include "../object_internal.h" namespace tvm { namespace runtime { @@ -551,19 +550,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") const auto& contexts = GetAllContext(args); *rv = GraphRuntimeCreate(args[0], args[1], contexts); }); - -TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) << "The expected number of arguments for " - "graph_runtime.remote_create is " - "at least 4, but it has " - << args.num_args; - void* mhandle = args[1]; - ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle); - - const auto& contexts = GetAllContext(args); - *rv = GraphRuntimeCreate( - args[0], GetRef(mnode), contexts); - }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 27cbc37d1a06..1042a4f68e5e 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -41,7 +41,7 @@ class RPCWrappedFunc { } void operator()(TVMArgs args, TVMRetValue *rv) const { - sess_->CallFunc(handle_, args, rv, &fwrap_); + sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_); } ~RPCWrappedFunc() { try { @@ -55,6 +55,9 @@ class RPCWrappedFunc { TVMArgs args, TVMRetValue* rv); + static void* UnwrapRemote(int rpc_sess_table_index, + const TVMArgValue& arg); + // deleter of RPC remote array static void RemoteNDArrayDeleter(NDArray::Container* ptr) { RemoteSpace* space = static_cast(ptr->dl_tensor.data); @@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode { PackedFunc fwrap_; }; +void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, + const TVMArgValue& arg) { + if (arg.type_code() == kModuleHandle) { + Module mod = arg; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "ValueError: Cannot pass a non-RPC module to remote"; + auto* rmod = static_cast(mod.operator->()); + CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index) + << "ValueError: Cannot pass in module into a different remote session"; + return rmod->module_handle(); + } else { + LOG(FATAL) << "ValueError: Cannot pass type " + << runtime::TypeCode2Str(arg.type_code()) + << " as an argument to the remote"; + return nullptr; + } +} + void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, TVMArgs args, TVMRetValue *rv) { diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index c7e524d2f295..16b0e7f69529 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream { return ctx; } // Send Packed sequence to writer. + // + // client_mode: whether we are in client mode. + // + // funwrap: auxiliary function to unwrap remote Object + // when it is provided, we need to unwrap objects. + // // return_ndarray is a special flag to handle returning of ndarray // In this case, we return the shape, context and data of the array, // as well as a customized PackedFunc that handles deletion of // the array in the remote. void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, - int n, + int num_args, + bool client_mode, + FUnwrapRemoteObject funwrap = nullptr, bool return_ndarray = false) { - this->Write(n); - for (int i = 0; i < n; ++i) { + std::swap(client_mode_, client_mode); + + this->Write(num_args); + for (int i = 0; i < num_args; ++i) { int tcode = type_codes[i]; if (tcode == kNDArrayContainer) tcode = kArrayHandle; this->Write(tcode); } // Argument packing. - for (int i = 0; i < n; ++i) { + for (int i = 0; i < num_args; ++i) { int tcode = type_codes[i]; TVMValue value = arg_values[i]; switch (tcode) { @@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream { break; } case kFuncHandle: - case kModuleHandle: + case kModuleHandle: { + // always send handle in 64 bit. + uint64_t handle; + // allow pass module as argument to remote. + if (funwrap != nullptr) { + void* remote_handle = (*funwrap)( + rpc_sess_table_index_, + runtime::TVMArgValue(value, tcode)); + handle = reinterpret_cast(remote_handle); + } else { + CHECK(!client_mode_) + << "Cannot directly pass remote object as argument"; + handle = reinterpret_cast(value.v_handle); + } + this->Write(handle); + break; + } case kHandle: { // always send handle in 64 bit. uint64_t handle = reinterpret_cast(value.v_handle); @@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream { } } } + std::swap(client_mode_, client_mode); } // Endian aware IO handling @@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream { case kHandle: case kStr: case kBytes: + case kModuleHandle: case kTVMContext: { this->RequestBytes(sizeof(TVMValue)); break; } - case kFuncHandle: - case kModuleHandle: { + case kFuncHandle: { CHECK(client_mode_) << "Only client can receive remote functions"; this->RequestBytes(sizeof(TVMValue)); break; @@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream { TVMValue ret_value; ret_value.v_str = e.what(); int ret_tcode = kStr; - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } } this->SwitchToState(kRecvCode); @@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream { } } this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); arg_recv_stage_ = 0; this->SwitchToState(kRecvCode); } @@ -734,7 +761,7 @@ class RPCSession::EventHandler : public dmlc::Stream { if (rv.type_code() == kStr) { ret_value.v_str = rv.ptr()->c_str(); ret_tcode = kStr; - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } else if (rv.type_code() == kBytes) { std::string* bytes = rv.ptr(); TVMByteArray arr; @@ -742,14 +769,14 @@ class RPCSession::EventHandler : public dmlc::Stream { arr.size = bytes->length(); ret_value.v_handle = &arr; ret_tcode = kBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } else if (rv.type_code() == kFuncHandle || rv.type_code() == kModuleHandle) { // always send handle in 64 bit. CHECK(!client_mode_) << "Only server can send function and module handle back."; rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } else if (rv.type_code() == kNDArrayContainer) { // always send handle in 64 bit. CHECK(!client_mode_) @@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream { NDArray::Container* nd = static_cast(ret_value_pack[0].v_handle); ret_value_pack[1].v_handle = nd; ret_tcode_pack[1] = kHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true); + SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); } else { ret_value = rv.value(); ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } } catch (const std::runtime_error& e) { RPCCode code = RPCCode::kException; this->Write(code); ret_value.v_str = e.what(); ret_tcode = kStr; - SendPackedSeq(&ret_value, &ret_tcode, 1); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); } } @@ -873,7 +900,7 @@ void RPCSession::Init() { &reader_, &writer_, table_index_, name_, &remote_key_); // Quick function to call remote. call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); }); @@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { void RPCSession::CallFunc(void* h, TVMArgs args, TVMRetValue* rv, + FUnwrapRemoteObject funwrap, const PackedFunc* fwrap) { std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCallFunc; handler_->Write(code); uint64_t handle = reinterpret_cast(h); handler_->Write(handle); - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); + handler_->SendPackedSeq( + args.values, args.type_codes, args.num_args, true, funwrap); code = HandleUntilReturnEvent(rv, true, fwrap); CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); } diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index ab5f16dadc46..cc04af80232f 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -90,6 +90,16 @@ enum class RPCCode : int { kNDArrayFree }; +/*! + * \brief Function that unwraps a remote object to its handle. + * \param rpc_sess_table_index RPC session table index for validation. + * \param obj Handle to the object argument. + * \return The corresponding handle. + */ +typedef void* (*FUnwrapRemoteObject)( + int rpc_sess_table_index, + const TVMArgValue& obj); + /*! * \brief Abstract channel interface used to create RPCSession. */ @@ -144,11 +154,13 @@ class RPCSession { * \param handle The function handle * \param args The arguments * \param rv The return value. + * \param funpwrap Function that takes a remote object and returns the raw handle. * \param fwrap Wrapper function to turn Function/Module handle into real return. */ void CallFunc(RPCFuncHandle handle, TVMArgs args, TVMRetValue* rv, + FUnwrapRemoteObject funwrap, const PackedFunc* fwrap); /*! * \brief Copy bytes into remote array content.