From 50921a4ccc0f53d5777ac87cd1f028c0b68742c2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Apr 2020 11:48:56 -0700 Subject: [PATCH] [REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra. This PR refactors the RPC protocol to make it more modularized. - RPCSession: represent a set of features that need to be implemented - RPCEndPont: End point that forwards the RPCSession requests over a communication channel. - RPCModule: Exposes an RPCSession as an rpc device in the TVM Runtime API. In the new design, the local machine is presented as a special case of RPCSession. The remote is just another client session that calls into RPCEndPoint. The RPC communication path is as follows. ``` client -> ClientSession -> EndPoint[client@n0] -> networking[between n0 <=> n1] -> EndPoint[server@n1] -> LocalSession[@n1] ``` Because of the new modular design, we can now chain more sessions together. For example, we can now run the following proxy setup (testcase in test_runtime_rpc.test_session_constructor). ``` client -> ClientSession -> Endpoint[client@n0] -> networking[between n0 <=> n1] -> Endpoint[server@n1] -> ClientSession -> Endpoint[client@n1] -> networking[between n1 <=> n2] -> Endpoint[server@n2] -> LocalSession[@n2] ``` We can also implement other types of Sessions. For example, we could make uTVM session a special case of the RPCSession and use the same mechanism for session management. We also add more comments about the internal of the RPC. The communication protocol is simplfied using a similar convention as PackedFunc. This allows us to further reduce the amount of special remote syscalls. Due to the major improvement and simplification, we are making a non-compatible update to the RPC protocol. It means that the client and server needs to be upgraded to together in order for it to function correctly. This PR also introduces a versioning mechanism to the current RPC procotol, so that future upgrade will be produce more user friendly with error messages. --- apps/cpp_rpc/rpc_server.cc | 20 +- apps/cpp_rpc/rpc_tracker_client.h | 2 +- include/tvm/runtime/device_api.h | 10 +- .../org/apache/tvm/contrib/GraphRuntime.java | 9 +- .../main/java/org/apache/tvm/rpc/Client.java | 2 +- .../org/apache/tvm/rpc/NativeServerLoop.java | 2 +- .../java/org/apache/tvm/rpc/RPCSession.java | 4 +- python/tvm/contrib/graph_runtime.py | 9 +- python/tvm/error.py | 5 + python/tvm/rpc/_ffi_api.py | 21 + python/tvm/rpc/base.py | 7 - python/tvm/rpc/client.py | 65 +- python/tvm/rpc/proxy.py | 4 +- python/tvm/rpc/server.py | 7 +- src/runtime/module.cc | 2 +- src/runtime/registry.cc | 22 +- src/runtime/rpc/rpc_channel.cc | 52 + src/runtime/rpc/rpc_channel.h | 90 ++ src/runtime/rpc/rpc_device_api.cc | 53 +- src/runtime/rpc/rpc_endpoint.cc | 1334 +++++++++++++++++ src/runtime/rpc/rpc_endpoint.h | 253 ++++ src/runtime/rpc/rpc_event_impl.cc | 18 +- src/runtime/rpc/rpc_local_session.cc | 146 ++ src/runtime/rpc/rpc_local_session.h | 82 + src/runtime/rpc/rpc_module.cc | 401 +++-- src/runtime/rpc/rpc_server_env.cc | 7 +- src/runtime/rpc/rpc_session.cc | 1263 +--------------- src/runtime/rpc/rpc_session.h | 407 ++--- src/runtime/rpc/rpc_socket_impl.cc | 59 +- tests/python/unittest/test_runtime_rpc.py | 65 +- web/tvm_runtime.js | 2 +- 31 files changed, 2642 insertions(+), 1781 deletions(-) create mode 100644 python/tvm/rpc/_ffi_api.py create mode 100644 src/runtime/rpc/rpc_channel.cc create mode 100644 src/runtime/rpc/rpc_channel.h create mode 100644 src/runtime/rpc/rpc_endpoint.cc create mode 100644 src/runtime/rpc/rpc_endpoint.h create mode 100644 src/runtime/rpc/rpc_local_session.cc create mode 100644 src/runtime/rpc/rpc_local_session.h diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index ea4ab00c113b9..57a68f452d3d6 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -33,7 +33,7 @@ #include #include "../../src/support/socket.h" -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" #include "rpc_env.h" #include "rpc_server.h" @@ -86,7 +86,7 @@ class RPCServer { tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), custom_addr_(std::move(custom_addr)) { - + } /*! @@ -98,7 +98,7 @@ class RPCServer { tracker_sock_.Close(); listen_sock_.Close(); } catch(...) { - + } } @@ -144,7 +144,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -197,7 +197,7 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - + } auto dur = high_resolution_clock::now() - start_time; @@ -217,10 +217,10 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, + support::SockAddr* addr, + std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) { tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); - + sock.Close(); env.CleanUp(); } @@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") +TVM_REGISTER_GLOBAL("rpc.ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index dfd576f4c1951..9b9a707dd3763 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -31,7 +31,7 @@ #include #include -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_end_point.h" #include "../../src/support/socket.h" namespace tvm { diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f2ddc84e9f98b..12069182354b0 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI { * \param event_dst The destination stream to synchronize. */ virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, - TVMStreamHandle event_dst); - /*! + TVMStreamHandle event_src, + TVMStreamHandle event_dst); + /*! * \brief Allocate temporal workspace for backend execution. * * \note We have the following assumption about backend temporal @@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI { * as OpenGL, as nbytes is sufficient for most backends. */ virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + size_t nbytes, + DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. * diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java index c31c67f283af5..0f3eb0b3c2bdc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java @@ -51,7 +51,7 @@ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx throw new IllegalArgumentException("libmod.typeKey != rpc"); } final int sessIndex = (int) ((Function) reflectionStaticCall( - RPC.class, "getApi", "_SessTableIndex")) + RPC.class, "getApi", "SessTableIndex")) .pushArg(libmod).invoke().asLong(); if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { throw new IllegalArgumentException(String.format( @@ -59,13 +59,6 @@ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx sessIndex, reflectionGetField(rpcSession, "tblIndex"))); } - Function rpcModuleHandle = (Function) reflectionStaticCall( - RPC.class, "getApi","_ModuleHandle"); - if (rpcModuleHandle == null) { - throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." - + "Did you compile tvm_runtime with the correct version?"); - } - Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); if (fcreate == null) { throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 5178ac900a367..69321c3b51c80 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -29,7 +29,7 @@ public class Client { * @return The connected session. */ public static RPCSession connect(String url, int port, String key) { - Function doConnect = RPC.getApi("_Connect"); + Function doConnect = RPC.getApi("Connect"); if (doConnect == null) { throw new RuntimeException("Please compile with USE_RPC=1"); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java index 29a457f39a40c..1f3191fb2e8ca 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); + RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 92b328488b403..b9f621473cf4d 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -39,7 +39,7 @@ public class RPCSession { RPCSession(Module sess) { session = sess; - tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong(); } /** @@ -237,7 +237,7 @@ public byte[] download(String path) { * @return The remote module containing remote function. */ public Module loadModule(String path) { - return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); } diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 73235f71c77ba..740d1c3f19f3c 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -18,9 +18,10 @@ import numpy as np import tvm._ffi -from .._ffi.base import string_types -from .._ffi.runtime_ctypes import TVMContext -from ..rpc import base as rpc_base +from tvm.rpc import _ffi_api as _rpc_ffi_api +from tvm.rpc import base as rpc_base +from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import TVMContext def create(graph_json_str, libmod, ctx): @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx): device_type = cur_ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex( + assert _rpc_ffi_api.SessTableIndex( libmod) == cur_ctx._rpc_sess._tbl_index num_rpc_ctx += 1 device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK diff --git a/python/tvm/error.py b/python/tvm/error.py index 4c3e6060c25a8..7366fdb3be39e 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -57,6 +57,11 @@ def __init__(self, msg): register_error("KeyError", KeyError) +@register_error +class RPCError(RuntimeError): + """Error thrown by the RPC call.""" + + @register_error class OpError(TVMError): """Base class of all operator errors in frontends.""" diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py new file mode 100644 index 0000000000000..1a7cc739b5c1f --- /dev/null +++ b/python/tvm/rpc/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.rpc""" +import tvm._ffi + + +tvm._ffi._init_api("rpc", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index bc81534a12d99..f0e33f8503f28 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -17,8 +17,6 @@ """Base definitions for RPC.""" # pylint: disable=invalid-name -from __future__ import absolute_import - import socket import time import json @@ -26,7 +24,6 @@ import struct import random import logging -import tvm._ffi from .._ffi.base import py_str @@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5): logger.warning("Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period) time.sleep(retry_period) - - -# Still use tvm.rpc for the foreign functions -tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base") diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ed57e0d4276d4..72de36106cb1c 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -15,19 +15,17 @@ # specific language governing permissions and limitations # under the License. """RPC client tools""" -from __future__ import absolute_import - import os import socket import struct import time -import tvm._ffi -from tvm.contrib import util + from tvm._ffi.base import TVMError from tvm.runtime import ndarray as nd -from tvm.runtime import load_module as _load_module from . import base +from . import server +from . import _ffi_api class RPCSession(object): @@ -38,7 +36,7 @@ class RPCSession(object): # pylint: disable=invalid-name def __init__(self, sess): self._sess = sess - self._tbl_index = base._SessTableIndex(sess) + self._tbl_index = _ffi_api.SessTableIndex(sess) self._remote_funcs = {} def get_function(self, name): @@ -145,7 +143,7 @@ def load_module(self, path): m : Module The remote module containing remote function. """ - return base._LoadRemoteModule(self._sess, path) + return _ffi_api.LoadRemoteModule(self._sess, path) def cpu(self, dev_id=0): """Construct CPU device.""" @@ -184,27 +182,8 @@ class LocalSession(RPCSession): """ def __init__(self): # pylint: disable=super-init-not-called - self.context = nd.context - self.get_function = tvm._ffi.get_global_func - self._temp = util.tempdir() - - def upload(self, data, target=None): - if isinstance(data, bytearray): - if not target: - raise ValueError("target must present when file is a bytearray") - blob = data - else: - blob = bytearray(open(data, "rb").read()) - if not target: - target = os.path.basename(data) - with open(self._temp.relpath(target), "wb") as f: - f.write(blob) - - def download(self, path): - return bytearray(open(self._temp.relpath(path), "rb").read()) - - def load_module(self, path): - return _load_module(self._temp.relpath(path)) + self._temp = server._server_env([]) + RPCSession.__init__(self, _ffi_api.LocalSession()) class TrackerSession(object): @@ -378,7 +357,7 @@ def request_and_run(self, key, max_retry, str(last_err))) -def connect(url, port, key="", session_timeout=0): +def connect(url, port, key="", session_timeout=0, session_constructor=None): """Connect to RPC Server Parameters @@ -397,15 +376,41 @@ def connect(url, port, key="", session_timeout=0): the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. + session_constructor: List + List of additional arguments to passed as the remote session constructor. + Returns ------- sess : RPCSession The connected session. + + Examples + -------- + Normal usage + .. code-block:: python + + client = rpc.connect(server_url, server_port, server_key) + + Session_constructor can be used to customize the session in the remote + The following code connects to a remote internal server via a proxy + by constructing another RPCClientSession on the proxy machine and use that + as the serving session of the proxy endpoint. + + .. code-block:: python + + client_via_proxy = rpc.connect( + proxy_server_url, proxy_server_port, proxy_server_key, + session_constructor=[ + "rpc.Connect", internal_url, internal_port, internal_key]) + """ try: if session_timeout: key += " -timeout=%s" % str(session_timeout) - sess = base._Connect(url, port, key) + session_constructor = session_constructor if session_constructor else [] + if not isinstance(session_constructor, (list, tuple)): + raise TypeError("Expect the session constructor to be a list or tuple") + sess = _ffi_api.Connect(url, port, key, *session_constructor) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index c3a3647948eea..555fa549ad531 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -42,7 +42,7 @@ raise ImportError( "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) -from . import base +from . import _ffi_api from .base import TrackerCode from .server import _server_env from .._ffi.base import py_str @@ -549,7 +549,7 @@ def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) - on_message = base._CreateEventDrivenServer( + on_message = _ffi_api.CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 03749c1c17e4a..15a3c7de789d3 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import util +from . import _ffi_api from . import base from . base import TrackerCode @@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) @@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" sockfd = sock.fileno() temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) + _ffi_api.ServerLoop(sockfd) if not work_path: temp.remove() logger.info("Finish serving %s", addr) @@ -330,7 +331,7 @@ def __init__(self, utvm_dev_config_args=None, ): try: - if base._ServerLoop is None: + if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") diff --git a/src/runtime/module.cc b/src/runtime/module.cc index d2ed7ff9e2b7b..813a79d43c061 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -36,7 +36,7 @@ void ModuleNode::Import(Module other) { if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { - fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); + fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule"); CHECK(fimport_ != nullptr); } (*fimport_)(GetRef(this), other); diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 4717d89e33c15..855a342a7e97a 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -37,7 +37,7 @@ struct Registry::Manager { // map storing the functions. // We delibrately used raw pointer // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction. + // and the resource can become invalid because of indeterminstic order of destruction and forking. // The resources will only be recycled during program exit. std::unordered_map fmap; // mutex @@ -60,20 +60,18 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) return *this; } -Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) +Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); - auto it = m->fmap.find(name); - if (it == m->fmap.end()) { - Registry* r = new Registry(); - r->name_ = name; - m->fmap[name] = r; - return *r; - } else { - CHECK(override) - << "Global PackedFunc " << name << " is already registered"; - return *it->second; + if (m->fmap.count(name)) { + CHECK(can_override) + << "Global PackedFunc " << name << " is already registered"; } + + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; } bool Registry::Remove(const std::string& name) { diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc new file mode 100644 index 0000000000000..f8dc6e6363245 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.cc + */ +#include +#include "rpc_channel.h" + +namespace tvm { +namespace runtime { + +size_t CallbackChannel::Send(const void* data, size_t size) { + TVMByteArray bytes; + bytes.data = static_cast(data); + bytes.size = size; + int64_t n = fsend_(bytes); + if (n == -1) { + LOG(FATAL) << "CallbackChannel::Send"; + } + return static_cast(n); +} + +size_t CallbackChannel::Recv(void* data, size_t size) { + TVMRetValue ret = frecv_(size); + + if (ret.type_code() != kTVMBytes) { + LOG(FATAL) << "CallbackChannel::Recv"; + } + std::string* bytes = ret.ptr(); + memcpy(static_cast(data), bytes->c_str(), bytes->length()); + return bytes->length(); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h new file mode 100644 index 0000000000000..f8d498231cd42 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_ +#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Abstract channel interface used to create RPCEndpoint. + */ +class RPCChannel { + public: + /*! \brief virtual destructor */ + virtual ~RPCChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + virtual size_t Send(const void* data, size_t size) = 0; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + virtual size_t Recv(void* data, size_t size) = 0; +}; + +/*! + * \brief RPC channel which callback + * frontend (Python/Java/etc.)'s send & recv function + */ +class CallbackChannel final : public RPCChannel { + public: + explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} + + ~CallbackChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + size_t Send(const void* data, size_t size) final; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + size_t Recv(void* data, size_t size) final; + + private: + PackedFunc fsend_; + PackedFunc frecv_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_H_ diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 9fd45acd14bf4..ade4d1683fb18 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "rpc_session.h" namespace tvm { @@ -31,20 +32,24 @@ namespace runtime { class RPCDeviceAPI final : public DeviceAPI { public: void SetDevice(TVMContext ctx) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevSetDevice, ctx); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx); } + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { - *rv = GetSess(ctx)->CallRemote( - RPCCode::kDevGetAttr, ctx, static_cast(kind)); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv); } + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { auto sess = GetSess(ctx); - void *data = sess->CallRemote( - RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + auto remote_ctx = RemoveSessMask(ctx); + void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace( + remote_ctx, nbytes, alignment, type_hint); + RemoteSpace* space = new RemoteSpace(); space->data = data; space->sess = std::move(sess); @@ -52,9 +57,10 @@ class RPCDeviceAPI final : public DeviceAPI { } void FreeDataSpace(TVMContext ctx, void* ptr) final { RemoteSpace* space = static_cast(ptr); + auto remote_ctx = RemoveSessMask(ctx); try { - GetSess(ctx)->CallRemote( - RPCCode::kDevFreeData, ctx, space->data); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace( + remote_ctx, space->data); } catch (const dmlc::Error& e) { // fault tolerance to remote close. } @@ -75,29 +81,35 @@ class RPCDeviceAPI final : public DeviceAPI { to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; - GetSess(ctx_from)->CallRemote( - RPCCode::kCopyAmongRemote, - static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, ctx_from, ctx_to, type_hint, stream); + auto remote_ctx_from = RemoveSessMask(ctx_from); + auto remote_ctx_to = RemoveSessMask(ctx_to); + auto remote_ctx = remote_ctx_from; + if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; + GetSess(ctx_from)->GetDeviceAPI(remote_ctx) + ->CopyDataFromTo(static_cast(from)->data, from_offset, + static_cast(to)->data, to_offset, + size, remote_ctx_from, remote_ctx_to, type_hint, stream); } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { + auto remote_ctx_from = RemoveSessMask(ctx_from); GetSess(ctx_from)->CopyFromRemote( static_cast(from)->data, from_offset, - to, to_offset, size, ctx_from, type_hint); + to, to_offset, size, remote_ctx_from, type_hint); } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { + auto remote_ctx_to = RemoveSessMask(ctx_to); GetSess(ctx_to)->CopyToRemote( - (void*)from, from_offset, // NOLINT(*) + const_cast(from), from_offset, static_cast(to)->data, to_offset, - size, ctx_to, type_hint); + size, remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevStreamSync, ctx, stream); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(ctx, stream); } private: @@ -107,6 +119,11 @@ class RPCDeviceAPI final : public DeviceAPI { int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } + + static TVMContext RemoveSessMask(TVMContext ctx) { + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } }; TVM_REGISTER_GLOBAL("device_api.rpc") diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc new file mode 100644 index 0000000000000..a7b81d10043ac --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -0,0 +1,1334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_session.cc + * \brief RPC session for remote function call. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "../object_internal.h" +#include "../../support/ring_buffer.h" + +namespace tvm { +namespace runtime { + +/*! \brief Temp buffer for data array */ +struct RPCByteArrayBuffer { + TVMByteArray arr; + std::string data; +}; + +/*! \brief Temp buffer for data array */ +struct RPCDataArrayBuffer { + DLTensor tensor; + std::vector shape; +}; + +/*! + * \brief Temporal argument buffer. + */ +struct RPCArgBuffer { + // The argument values + std::vector value; + // The type codes. + std::vector tcode; + // Temporal resources. + std::vector> temp_bytes; + // Temporal array + std::vector> temp_array; + // convert buffer as TVMArgs + TVMArgs AsTVMArgs() const { + return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); + } +}; + +/*! + * Event-driven state-machine based handlers for RPCEndpoint. + * + * It provides two key functions: + * + * - SendPackedSeq: send the arguments over to the peer + * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). + */ +class RPCEndpoint::EventHandler : public dmlc::Stream { + public: + EventHandler(support::RingBuffer* reader, + support::RingBuffer* writer, + std::string name, + std::string* remote_key) + : reader_(reader), + writer_(writer), + name_(name), + remote_key_(remote_key) { + this->Clear(); + if (*remote_key == "%toinit") { + state_ = kInitHeader; + remote_key_->resize(0); + pending_request_bytes_ = sizeof(int32_t); + } + } + + /*! + * \brief Bytes needed to fulfill current request + */ + size_t BytesNeeded() const { + if (reader_->bytes_available() < pending_request_bytes_) { + return pending_request_bytes_ - reader_->bytes_available(); + } else { + return 0; + } + } + + /*! + * \brief Request number of bytes from the reader. + * \param nbytes The number of bytes + */ + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + + /*! \return Whether we are ready to handle next request. */ + bool Ready() const { + return reader_->bytes_available() >= pending_request_bytes_; + } + + /*! \return Whether we can perform a clean shutdown */ + bool CanCleanShutdown() const { + return state_ == kRecvCode; + } + + /*! \brief Finish the copy ack stage. */ + void FinishCopyAck() { + this->SwitchToState(kRecvCode); + } + + /*! + * \brief Enter the io loop until the next event. + * \param client_mode Whether we are in the client. + * \param setreturn The function to set the return value encoding. + * \return The function to set return values when there is a return event. + */ + RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { + std::swap(client_mode_, client_mode); + + while (this->Ready()) { + switch (state_) { + case kInitHeader: HandleInitHeader(); break; + case kRecvCode: HandleRecvCode(); break; + case kRecvCallHandle: { + CHECK(this->Read(&call_handle_)); + this->SwitchToState(kRecvPackedSeqNumArgs); + break; + } + case kRecvPackedSeqNumArgs: { + CHECK(this->Read(&num_packed_args_)); + arg_buf_.reset(new RPCArgBuffer()); + arg_buf_->value.resize(num_packed_args_); + arg_buf_->tcode.resize(num_packed_args_); + this->SwitchToState(kRecvPackedSeqTypeCode); + break; + } + case kRecvPackedSeqTypeCode: { + if (num_packed_args_ != 0) { + this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); + } + arg_index_ = 0; + arg_recv_stage_ = 0; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kRecvPackedSeqArg: { + this->HandleRecvPackedSeqArg(); + break; + } + case kDoCopyFromRemote: { + this->HandleCopyFromRemote(); + break; + } + case kDoCopyToRemote: { + this->HandleCopyToRemote(); + break; + } + case kReturnReceived: { + CHECK(setreturn != nullptr) << "fsetreturn not available"; + setreturn(arg_buf_->AsTVMArgs()); + this->SwitchToState(kRecvCode); + std::swap(client_mode_, client_mode); + return RPCCode::kReturn; + } + case kCopyAckReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kCopyAck; + } + case kShutdownReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kShutdown; + } + } + } + std::swap(client_mode_, client_mode); + return RPCCode::kNone; + } + + /*! \brief Clear all the states in the Handler.*/ + void Clear() { + state_ = kRecvCode; + pending_request_bytes_ = sizeof(RPCCode); + arg_recv_stage_ = 0; + arg_buf_.reset(); + } + + /*! + * \brief Validate that the arguments can be sent through RPC. + * \param arg_values The argument values. + * \param type_codes The type codes. + */ + void ValidateArguments(const TVMValue* arg_values, + const int* type_codes, + int num_args) { + TVMArgs args(arg_values, type_codes, num_args); + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i + << ", type " << args[i].AsObjectRef()->GetTypeKey() + << " is not supported by RPC"; + } else if (tcode == kTVMContext) { + DLContext ctx = args[i]; + CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) + << "InternalError: cannot pass RPC context in the channel"; + } + } + } + + /*! + * \brief Send packed argument sequnce to the other peer. + * + * This function serves as the foundational communication primitive between peers. + * + * TVMValue sequence encoding protocol(according to the type): + * + * - int/float/uint/bytes/str: Serialize all content. + * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t) + * - OpaqueHandle: send as uint64_t + * - ModuleHandle, PackedFuncHandle: send as uint64_t, + * The support to Module/PackedFuncHandle are reserved for arguments + * in the CallFunc from a client to server only. + * Note that we cannot simply take these argument out(as the handle) + * refers to a value on the remote(instead of local). + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + */ + void SendPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode) { + std::swap(client_mode_, client_mode); + + this->Write(num_args); + this->WriteArray(type_codes, num_args); + + // Argument packing. + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + TVMValue value = arg_values[i]; + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + this->Write(value.v_int64); + break; + } + case kTVMDataType: { + this->Write(value.v_type); + // padding + int32_t padding = 0; + this->Write(padding); + break; + } + case kTVMContext: { + this->Write(value.v_ctx); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + // always send handle in 64 bit. + uint64_t handle; + CHECK(client_mode_) + << "Cannot directly pass remote object in the return encoding"; + handle = reinterpret_cast(value.v_handle); + this->Write(handle); + break; + } + case kTVMOpaqueHandle: { + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(value.v_handle); + this->Write(handle); + break; + } + case kTVMNDArrayHandle: { + LOG(FATAL) << "Need to convert NDArray to DLTensor before sending"; + break; + } + case kTVMDLTensorHandle: { + DLTensor* arr = static_cast(value.v_handle); + TVMContext ctx; + uint64_t data; + // When we return NDArray, we directly return + // the space and the context + // The client will be further wrapping + ctx = arr->ctx; + data = reinterpret_cast(arr->data); + this->Write(data); + this->Write(ctx); + this->Write(arr->ndim); + this->Write(arr->dtype); + this->WriteArray(arr->shape, arr->ndim); + CHECK(arr->strides == nullptr) + << "Do not support strided remote array"; + CHECK_EQ(arr->byte_offset, 0) + << "Do not support send byte offset"; + break; + } + case kTVMNullptr: break; + case kTVMStr: { + const char* s = value.v_str; + uint64_t len = strlen(s); + this->Write(len); + this->WriteArray(s, len); + break; + } + case kTVMBytes: { + TVMByteArray* bytes = static_cast(arg_values[i].v_handle); + uint64_t len = bytes->size; + this->Write(len); + this->WriteArray(bytes->data, len); + break; + } + default: { + LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); + break; + } + } + } + std::swap(client_mode_, client_mode); + } + + // Endian aware IO handling + using Stream::Read; + using Stream::Write; + using Stream::ReadArray; + using Stream::WriteArray; + + inline bool Read(RPCCode* code) { + int cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast(cdata); + return true; + } + inline void Write(RPCCode code) { + int cdata = static_cast(code); + this->Write(cdata); + } + + protected: + enum State { + kInitHeader, + kRecvCode, + kRecvCallHandle, + kRecvPackedSeqNumArgs, + kRecvPackedSeqTypeCode, + kRecvPackedSeqArg, + kDoCopyFromRemote, + kDoCopyToRemote, + kReturnReceived, + kCopyAckReceived, + kShutdownReceived + }; + // Current state; + State state_; + // The RPCCode to be read. + RPCCode code_; + // Handle for the remote function call. + uint64_t call_handle_; + // Initialize remote header + bool init_header_step_{0}; + // Number of packed arguments. + int num_packed_args_; + // Current argument index. + int arg_index_; + // The stage of each argument receiver. + int arg_recv_stage_; + // Whether current handler is client or server mode. + bool client_mode_{false}; + // Argument buffer + std::unique_ptr arg_buf_; + // Temp byte buffer. + std::unique_ptr temp_bytes_; + // Temp array buffer. + std::unique_ptr temp_array_; + // Internal temporal data space. + std::string temp_data_; + // Temp variables for copy request state. + TVMContext copy_ctx_; + DLDataType copy_dtype_; + uint64_t copy_handle_, copy_offset_, copy_size_; + // State switcher + void SwitchToState(State state) { + // invariant + CHECK_EQ(pending_request_bytes_, 0U) + << "state=" << state; + state_ = state; + switch (state) { + case kInitHeader: { + LOG(FATAL) << "cannot switch to init header"; + break; + } + case kRecvCode: { + this->RequestBytes(sizeof(RPCCode)); + break; + } + case kRecvCallHandle: { + this->RequestBytes(sizeof(call_handle_)); + break; + } + case kRecvPackedSeqNumArgs: { + this->RequestBytes(sizeof(num_packed_args_)); + break; + } + case kRecvPackedSeqTypeCode: { + this->RequestBytes(sizeof(int) * num_packed_args_); + break; + } + case kRecvPackedSeqArg: { + CHECK_LE(arg_index_, num_packed_args_); + if (arg_index_ == num_packed_args_) { + // The function can change state_ again. + HandlePackedCall(); + } else { + RequestRecvPackedSeqArg(); + } + break; + } + case kDoCopyFromRemote: { + this->RequestBytes(sizeof(uint64_t) * 3); + this->RequestBytes(sizeof(TVMContext)); + this->RequestBytes(sizeof(DLDataType)); + break; + } + case kDoCopyToRemote: { + this->RequestBytes(sizeof(uint64_t) * 3); + this->RequestBytes(sizeof(TVMContext)); + this->RequestBytes(sizeof(DLDataType)); + break; + } + case kCopyAckReceived: + case kReturnReceived: + case kShutdownReceived: { + break; + } + } + } + // Requets bytes needed for next computation. + void RequestRecvPackedSeqArg() { + CHECK_EQ(arg_recv_stage_, 0); + int tcode = arg_buf_->tcode[arg_index_]; + static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: + case kTVMDataType: + case kTVMOpaqueHandle: + case kTVMStr: + case kTVMBytes: + case kTVMModuleHandle: + case kTVMContext: { + this->RequestBytes(sizeof(TVMValue)); break; + } + case kTVMPackedFuncHandle: { + CHECK(client_mode_) + << "Only client can receive remote functions"; + this->RequestBytes(sizeof(TVMValue)); break; + } + case kTVMNullptr: break; + case kTVMDLTensorHandle: { + this->RequestBytes(sizeof(uint64_t)); + this->RequestBytes(sizeof(TVMContext)); + this->RequestBytes(sizeof(int)); + this->RequestBytes(sizeof(DLDataType)); + break; + } + default: { + LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); + break; + } + } + } + // Handler for packed sequence argument receive. + void HandleRecvPackedSeqArg() { + CHECK_LT(arg_index_, num_packed_args_); + int tcode = arg_buf_->tcode[arg_index_]; + TVMValue& value = arg_buf_->value[arg_index_]; + if (arg_recv_stage_ == 0) { + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + this->Read(&(value.v_int64)); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMDataType: { + this->Read(&(value.v_type)); + int32_t padding = 0; + this->Read(&padding); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMContext: { + this->Read(&(value.v_ctx)); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: + case kTVMOpaqueHandle: { + // always send handle in 64 bit. + uint64_t handle; + this->Read(&handle); + value.v_handle = reinterpret_cast(handle); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMNullptr: { + value.v_handle = nullptr; + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMStr: + case kTVMBytes: { + uint64_t len; + this->Read(&len); + temp_bytes_.reset( new RPCByteArrayBuffer()); + temp_bytes_->data.resize(len); + arg_recv_stage_ = 1; + this->RequestBytes(len); + break; + } + case kTVMDLTensorHandle: { + temp_array_.reset(new RPCDataArrayBuffer()); + uint64_t handle; + this->Read(&handle); + DLTensor& tensor = temp_array_->tensor; + tensor.data = reinterpret_cast(handle); + this->Read(&(tensor.ctx)); + this->Read(&(tensor.ndim)); + this->Read(&(tensor.dtype)); + temp_array_->shape.resize(tensor.ndim); + tensor.shape = temp_array_->shape.data(); + arg_recv_stage_ = 1; + tensor.strides = nullptr; + tensor.byte_offset = 0; + this->RequestBytes(sizeof(int64_t) * tensor.ndim); + break; + } + default: { + LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); + break; + } + } + } else { + CHECK_EQ(arg_recv_stage_, 1); + if (tcode == kTVMStr || tcode == kTVMBytes) { + if (temp_bytes_->data.size() != 0) { + this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); + } + if (tcode == kTVMStr) { + value.v_str = temp_bytes_->data.c_str(); + } else { + temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); + temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); + value.v_handle = &(temp_bytes_->arr); + } + arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); + } else { + CHECK_EQ(tcode, kTVMDLTensorHandle); + DLTensor& tensor = temp_array_->tensor; + this->ReadArray(tensor.shape, tensor.ndim); + value.v_handle = &tensor; + arg_buf_->temp_array.emplace_back(std::move(temp_array_)); + } + ++arg_index_; + arg_recv_stage_ = 0; + this->SwitchToState(kRecvPackedSeqArg); + } + } + // handler for initial header read + void HandleInitHeader() { + if (init_header_step_ == 0) { + int32_t len; + this->Read(&len); + remote_key_->resize(len); + init_header_step_ = 1; + this->RequestBytes(len); + return; + } else { + CHECK_EQ(init_header_step_, 1); + this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->SwitchToState(kRecvCode); + } + } + // Handler for read code. + void HandleRecvCode() { + this->Read(&code_); + if (code_ >= RPCCode::kSyscallCodeStart) { + SwitchToState(kRecvPackedSeqNumArgs); + return; + } + // invariant. + CHECK_EQ(arg_recv_stage_, 0); + switch (code_) { + case RPCCode::kCallFunc: { + SwitchToState(kRecvCallHandle); + break; + } + case RPCCode::kException: + case RPCCode::kReturn: { + SwitchToState(kRecvPackedSeqNumArgs); + break; + } + case RPCCode::kCopyFromRemote: { + SwitchToState(kDoCopyFromRemote); + break; + } + case RPCCode::kCopyToRemote: { + SwitchToState(kDoCopyToRemote); + break; + } + case RPCCode::kShutdown: { + SwitchToState(kShutdownReceived); + break; + } + case RPCCode::kCopyAck: { + SwitchToState(kCopyAckReceived); + break; + } + default: LOG(FATAL) << "Unknown event " << static_cast(code_); + } + } + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + if (ctx.device_type == kDLCPU) { + RPCCode code = RPCCode::kCopyAck; + this->Write(code); + char* dptr = reinterpret_cast(handle) + offset; + if (!DMLC_IO_NO_ENDIAN_SWAP) { + temp_data_.resize(0); + temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); + dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); + this->WriteArray(temp_data_.data(), num_bytes); + } else { + this->WriteArray(dptr, num_bytes); + } + } else { + temp_data_.resize(num_bytes + 1); + + try { + GetServingSession()->CopyFromRemote( + reinterpret_cast(handle), offset, + dmlc::BeginPtr(temp_data_), 0, + num_bytes, ctx, type_hint); + + RPCCode code = RPCCode::kCopyAck; + this->Write(code); + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); + } + this->WriteArray(&temp_data_[0], num_bytes); + } catch (const std::runtime_error &e) { + RPCCode code = RPCCode::kException; + this->Write(code); + TVMValue ret_value; + ret_value.v_str = e.what(); + int ret_tcode = kTVMStr; + SendPackedSeq(&ret_value, &ret_tcode, 1, false); + } + } + this->SwitchToState(kRecvCode); + } + + void HandleCopyToRemote() { + // use static variable to persist state. + // This only works if next stage is immediately after this. + if (arg_recv_stage_ == 0) { + CHECK(this->Read(©_handle_)); + CHECK(this->Read(©_offset_)); + CHECK(this->Read(©_size_)); + CHECK(this->Read(©_ctx_)); + CHECK(this->Read(©_dtype_)); + arg_recv_stage_ = 1; + CHECK_EQ(pending_request_bytes_, 0U); + this->RequestBytes(copy_size_); + } else { + CHECK_EQ(arg_recv_stage_, 1); + TVMValue ret_value; + ret_value.v_handle = nullptr; + int ret_tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + std::string errmsg; + + size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; + if (copy_ctx_.device_type == kDLCPU) { + char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; + this->ReadArray(dptr, copy_size_); + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); + } + } else { + temp_data_.resize(copy_size_ + 1); + this->ReadArray(&temp_data_[0], copy_size_); + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); + } + try { + GetServingSession()->CopyToRemote( + dmlc::BeginPtr(temp_data_), 0, + reinterpret_cast(copy_handle_), copy_offset_, + copy_size_, copy_ctx_, copy_dtype_); + } catch (const std::runtime_error &e) { + code = RPCCode::kException; + errmsg = e.what(); + ret_value.v_str = errmsg.c_str(); + ret_tcode = kTVMStr; + } + } + this->Write(code); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); + arg_recv_stage_ = 0; + this->SwitchToState(kRecvCode); + } + } + // Handle for packed call. + void HandlePackedCall(); + + void HandleNormalCallFunc() { + try { + // Need to move out, in case f itself need to call RecvPackedSeq + // Which will override argbuf again. + std::unique_ptr arg_buf = std::move(arg_buf_); + + TVMArgs args = arg_buf->AsTVMArgs(); + + GetServingSession()->CallFunc( + reinterpret_cast(call_handle_), + args.values, args.type_codes, args.size(), + [this](TVMArgs ret) { + RPCCode code = RPCCode::kReturn; + this->Write(code); + SendPackedSeq(ret.values, ret.type_codes, ret.size(), false); + }); + } catch (const std::runtime_error& e) { + RPCCode code = RPCCode::kException; + this->Write(code); + TVMValue ret_value; + ret_value.v_str = e.what(); + int ret_tcode = kTVMStr; + SendPackedSeq(&ret_value, &ret_tcode, 1, false); + } + } + + void HandleInitServer() { + auto handler = [this](RPCSession*, TVMArgs args, TVMRetValue* rv) { + CHECK(serving_session_ == nullptr) + << "Server has already been initialized"; + CHECK_GE(args.size(), 1U); + std::string client_protocol_ver = args[0]; + std::string server_protocol_ver = kRPCProtocolVer; + CHECK_EQ(client_protocol_ver, server_protocol_ver) + << "Server[" << name_ << "]: Client protocol version mismatch with the server " + << " server protocol=" << server_protocol_ver + << ", client protocol=" << client_protocol_ver; + + if (args.size() == 1) { + serving_session_ = std::make_shared(); + } else { + std::string constructor_name = args[1]; + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) + << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; + try { + fconstructor->CallPacked( + TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2), &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name + << ":\n" << e.what(); + } + + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name + << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); + } + }; + this->SysCallHandler(handler, false); + } + + // Handler for special syscalls that have a specific RPCCode. + template + void SysCallHandler(F f, bool need_serving_session = true) { + TVMRetValue rv; + TVMValue ret_value; + int ret_tcode; + try { + // Need to move out, in case f itself need to call RecvPackedSeq + // Which will override argbuf again. + std::unique_ptr args = std::move(arg_buf_); + + if (need_serving_session) { + f(GetServingSession(), args->AsTVMArgs(), &rv); + } else { + f(nullptr, args->AsTVMArgs(), &rv); + } + + RPCCode code = RPCCode::kReturn; + this->Write(code); + TVMArgsSetter setter(&ret_value, &ret_tcode); + setter(0, rv); + SendPackedSeq(&ret_value, &ret_tcode, 1, false); + } catch (const std::runtime_error& e) { + RPCCode code = RPCCode::kException; + this->Write(code); + ret_value.v_str = e.what(); + ret_tcode = kTVMStr; + SendPackedSeq(&ret_value, &ret_tcode, 1, false); + } + } + + private: + RPCSession* GetServingSession() const { + CHECK(serving_session_ != nullptr) + << "Need to call InitRemoteSession first before any further actions"; + return serving_session_.get(); + } + // Utility functions + // Internal read function, update pending_request_bytes_ + size_t Read(void* data, size_t size) final { + CHECK_LE(size, pending_request_bytes_); + reader_->Read(data, size); + pending_request_bytes_ -= size; + return size; + } + // wriite the data to the channel. + void Write(const void* data, size_t size) final { + writer_->Write(data, size); + } + // Number of pending bytes requests + size_t pending_request_bytes_{0}; + // The ring buffer to read data from. + support::RingBuffer* reader_; + // The ringr buffer to write reply to. + support::RingBuffer* writer_; + // The session used to serve the RPC requests. + std::shared_ptr serving_session_; + // Name of endpoint. + std::string name_; + // remote key + std::string* remote_key_; +}; + +RPCCode RPCEndpoint::HandleUntilReturnEvent( + bool client_mode, RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kCallFunc; + while (code != RPCCode::kReturn && + code != RPCCode::kShutdown && + code != RPCCode::kCopyAck) { + while (writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + size_t bytes_needed = handler_->BytesNeeded(); + if (bytes_needed != 0) { + size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { + return channel_->Recv(data, size); + }, bytes_needed); + if (n == 0) { + if (handler_->CanCleanShutdown()) { + return RPCCode::kShutdown; + } else { + LOG(FATAL) << "Channel closes before we get neded bytes"; + } + } + } + code = handler_->HandleNextEvent(client_mode, setreturn); + } + return code; +} + +void RPCEndpoint::Init() { + // Event handler + handler_ = std::make_shared( + &reader_, &writer_, name_, &remote_key_); + // Quick function to for syscall remote. + syscall_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + RPCCode code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + CHECK_EQ(args.size(), 1); + *rv = args[0]; + }); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + }); +} + +std::shared_ptr RPCEndpoint::Create( + std::unique_ptr channel, + std::string name, + std::string remote_key) { + std::shared_ptr endpt = std::make_shared(); + endpt->channel_ = std::move(channel); + endpt->name_ = std::move(name); + endpt->remote_key_ = std::move(remote_key); + endpt->Init(); + return endpt; +} + +RPCEndpoint::~RPCEndpoint() { + this->Shutdown(); +} + +void RPCEndpoint::Shutdown() { + if (channel_ != nullptr) { + RPCCode code = RPCCode::kShutdown; + handler_->Write(code); + // flush all writing buffer to output channel. + try { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + if (n == 0) break; + } + } catch (const dmlc::Error& e) { + } + channel_.reset(nullptr); + } +} + +void RPCEndpoint::ServerLoop() { + std::lock_guard lock(mutex_); + if (const auto* f = Registry::Get("tvm.rpc.server.start")) { + (*f)(); + } + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { + (*f)(); + } + channel_.reset(nullptr); +} + +int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kNone; + if (in_bytes.length() != 0) { + reader_.Write(in_bytes.c_str(), in_bytes.length()); + code = handler_->HandleNextEvent(false, [](TVMArgs) {}); + } + if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + if (code == RPCCode::kShutdown) return 0; + if (writer_.bytes_available() != 0) return 2; + return 1; +} + +void RPCEndpoint::InitRemoteSession(TVMArgs args) { + std::lock_guard lock(mutex_); + + std::vector values(1, TVMValue()); + std::vector tcodes(1); + + values.insert(values.end(), args.values, args.values + args.size()); + tcodes.insert(tcodes.end(), args.type_codes, args.type_codes + args.size()); + + TVMArgs init_args(values.data(), tcodes.data(), values.size()); + TVMArgsSetter setter(values.data(), tcodes.data()); + std::string protocol_ver = kRPCProtocolVer; + + setter(0, protocol_ver); + + RPCCode code = RPCCode::kInitServer; + handler_->Write(code); + TVMRetValue rv; + syscall_remote_.CallPacked(init_args, &rv); +} + +// Get remote function with name +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return) { + std::lock_guard lock(mutex_); + + handler_->ValidateArguments(arg_values, arg_type_codes, num_args); + + RPCCode code = RPCCode::kCallFunc; + handler_->Write(code); + uint64_t handle = reinterpret_cast(h); + handler_->Write(handle); + handler_->SendPackedSeq( + arg_values, arg_type_codes, num_args, true); + code = HandleUntilReturnEvent(true, encode_return); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +void RPCEndpoint::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_to, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyToRemote; + handler_->Write(code); + uint64_t handle = reinterpret_cast(to); + handler_->Write(handle); + uint64_t offset = static_cast(to_offset); + handler_->Write(offset); + uint64_t size = static_cast(data_size); + handler_->Write(size); + handler_->Write(ctx_to); + handler_->Write(type_hint); + handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); + + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn); +} + +void RPCEndpoint::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_from, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyFromRemote; + handler_->Write(code); + uint64_t handle = reinterpret_cast(from); + handler_->Write(handle); + uint64_t offset = static_cast(from_offset); + handler_->Write(offset); + uint64_t size = static_cast(data_size); + handler_->Write(size); + handler_->Write(ctx_from); + handler_->Write(type_hint); + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck); + reader_.Reserve(data_size); + handler_->RequestBytes(data_size); + while (!handler_->Ready()) { + size_t bytes_needed = handler_->BytesNeeded(); + reader_.WriteWithCallback([this](void* data, size_t size) { + size_t n = channel_->Recv(data, size); + CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; + return n; + }, bytes_needed); + } + handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); + handler_->FinishCopyAck(); +} + + +// SysCallEventHandler functions +void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + *rv = handler->GetFunction(name); +} + +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* handle = args[0]; + int type_code = args[1]; + handler->FreeHandle(handle, type_code); +} + +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + handler->GetDeviceAPI(ctx)->SetDevice(ctx); +} + +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + DeviceAttrKind kind = static_cast(args[1].operator int()); + if (kind == kExist) { + DeviceAPI* api = handler->GetDeviceAPI(ctx, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, rv); + } else { + *rv = 0; + } + } else { + handler->GetDeviceAPI(ctx)->GetAttr( + ctx, static_cast(kind), rv); + } +} + +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + uint64_t nbytes = args[1]; + uint64_t alignment = args[2]; + DLDataType type_hint = args[3]; + void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace( + ctx, nbytes, alignment, type_hint); + *rv = data; +} + +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + void* ptr = args[1]; + handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); +} + +void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle); +} + +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* from = args[0]; + uint64_t from_offset = args[1]; + void* to = args[2]; + uint64_t to_offset = args[3]; + uint64_t size = args[4]; + TVMContext ctx_from = args[5]; + TVMContext ctx_to = args[6]; + DLDataType type_hint = args[7]; + TVMStreamHandle stream = args[8]; + TVMContext ctx = ctx_from; + + if (ctx.device_type == kDLCPU) { + ctx = ctx_to; + } else { + CHECK(ctx_to.device_type == kDLCPU || + ctx_to.device_type == ctx_from.device_type) + << "Can not copy across different ctx types directly"; + } + handler->GetDeviceAPI(ctx)->CopyDataFromTo( + from, from_offset, + to, to_offset, + size, ctx_from, ctx_to, type_hint, stream); +} + +void RPCEndpoint::EventHandler::HandlePackedCall() { + CHECK_EQ(pending_request_bytes_, 0U); + if (code_ == RPCCode::kReturn) { + state_ = kReturnReceived; return; + } + // reset state to clean init state + state_ = kRecvCode; + this->RequestBytes(sizeof(RPCCode)); + // Event handler sit at clean state at this point. + switch (code_) { + case RPCCode::kCallFunc: { + this->HandleNormalCallFunc(); + break; + } + case RPCCode::kException: { + CHECK_EQ(arg_buf_->value.size(), 1U); + CHECK_EQ(arg_buf_->tcode[0], kTVMStr); + + std::ostringstream os; + os << "RPCError: Error caught from RPC call:\n" << arg_buf_->value[0].v_str; + arg_buf_.reset(); + throw dmlc::Error(os.str()); + break; + } + // system functions + case RPCCode::kInitServer: this->HandleInitServer(); break; + case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break; + case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break; + case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break; + case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break; + case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break; + case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break; + case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; + + default: LOG(FATAL) << "Unknown event " << static_cast(code_); + } + CHECK_EQ(state_, kRecvCode); +} + +/*! + * \brief RPC client session that proxies all calls to an endpoint. + */ +class RPCClientSession : public RPCSession, + public DeviceAPI { + public: + /*! + * \brief param endpoint The client endpoint of the session. + */ + explicit RPCClientSession(std::shared_ptr endpoint) + : endpoint_(endpoint) {} + + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final { + return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); + } + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc( + func, arg_values, arg_type_codes, num_args, fencode_return); + } + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final { + endpoint_->CopyToRemote( + from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + } + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final { + endpoint_->CopyFromRemote( + from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + } + + void FreeHandle(void* handle, int type_code) final { + endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); + } + + + void SetDevice(TVMContext ctx) final { + endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); + } + + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast(kind)); + } + + void* AllocDataSpace(TVMContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint) final { + return endpoint_->SysCallRemote( + RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeData, ptr); + } + + void CopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream) final { + endpoint_->SysCallRemote( + RPCCode::kCopyAmongRemote, + const_cast(from), from_offset, + to, to_offset, + size, + ctx_from, ctx_to, + type_hint, stream); + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream); + } + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { + return this; + } + + private: + std::shared_ptr endpoint_; +}; + +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint) { + return std::make_shared(endpoint); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h new file mode 100644 index 0000000000000..6bf6556d6164e --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.h @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_endpoint.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ +#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ + +#include +#include +#include +#include +#include +#include "rpc_session.h" +#include "rpc_channel.h" +#include "../../support/ring_buffer.h" + +namespace tvm { +namespace runtime { + +// Magic header for RPC data plane +const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; +// RPC procotol version, used for better error reporting. +constexpr const char* kRPCProtocolVer = "0.7.0"; + +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; + +/*! \brief The RPC code */ +enum class RPCCode : int { + kNone, + kCallFunc, + kReturn, + kException, + kShutdown, + kCopyFromRemote, + kCopyToRemote, + kCopyAck, + kInitServer, + // The following are syscall code that can send over CallRemote + kSyscallCodeStart = kInitServer, + kGetGlobalFunc, + kFreeHandle, + kDevSetDevice, + kDevGetAttr, + kDevAllocData, + kDevFreeData, + kDevStreamSync, + kCopyAmongRemote, +}; + +/*! + * \brief Communication endpoints to connect local and remote RPC sessions. + * An endpoint can either be a client or a server. + */ +class RPCEndpoint { + public: + /*! \brief virtual destructor */ + ~RPCEndpoint(); + /*! + * \brief The server loop that server runs to handle RPC calls. + */ + void ServerLoop(); + /*! + * \brief Message handling function for an async IO event driven server. + * + * Called when the server receives a message or an IO event update. + * Event driven handler will never call recv on the channel + * and always relies on the ServerIOEventHandler to receive the data. + * + * \param in_bytes The incoming bytes. + * \param event_flag 1: read_available, 2: write_avaiable. + * \return State flag. + * 1: continue running, no need to write, + * 2: need to write + * 0: shutdown + */ + int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag); + + /*! + * \brief Initalize the session on the remote that will be used to back all the RPC requests. + * + * If no session constructor arguments is passed, LocalSession will be used in the remote. + * Otherwise the remote serving session will be constructed using the arguments + * specified in the session_constructor. + * + * The construction rule can be summarized as follows: + * + * \code + * + * auto args = session_constructor; + * int n = args.size(); + * if (n != 0) { + * std::string constructor = args[0]; + * server.serving_session_ = GetGlobalFunc(constructor)( + * args[1], args[2] ... args[n - 1]) + * } else { + * server.serving_session_ = LocalSession(); + * } + * \endcode + * + * \param session_constructor Optional sequence of the remote sesssion constructor. + */ + void InitRemoteSession(TVMArgs session_constructor); + + /*! + * \brief Call into remote function + * \param handle The function handle + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to receive return value encodings. + */ + void CallFunc(RPCSession::PackedFuncHandle handle, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return); + /*! + * \brief Copy bytes into remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_to The target context. + * \param type_hint Hint of content data type. + */ + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint); + /*! + * \brief Copy bytes from remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_from The source context. + * \param type_hint Hint of content data type. + */ + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint); + + /*! + * \brief Call a remote defined system function with arguments. + * \param fcode The function code. + * \param args The arguments + * \return The returned remote value. + */ + template + inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args); + /*! + * \brief Create a RPC session with given channel. + * \param channel The communication channel. + * \param name The local name of the session, used for debug + * \param remote_key The remote key of the session + * if remote_key equals "%toinit", we need to re-intialize + * it by event handler. + * \params server_handler The request handler on the server side, + * can be nullptr on the client. + */ + static std::shared_ptr Create( + std::unique_ptr channel, + std::string name, + std::string remote_key); + + private: + class EventHandler; + // Handle events until receives a return + // Also flushes channels so that the function advances. + RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); + // Initalization + void Init(); + // Shutdown + void Shutdown(); + // Internal channel. + std::unique_ptr channel_; + // Internal mutex + std::recursive_mutex mutex_; + // Internal ring buffer. + support::RingBuffer reader_, writer_; + // Event handler. + std::shared_ptr handler_; + // syscall remote with specified function code. + PackedFunc syscall_remote_; + // The name of the session. + std::string name_; + // The remote key + std::string remote_key_; +}; + +/*! + * \brief Create an RPC client session from an RPC client endpoint. + * \param endpoint The endpoint. + * \return The created session. + */ +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint); + +// implementation of inline functions +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) { + std::lock_guard lock(mutex_); + writer_.Write(&code, sizeof(code)); + return syscall_remote_(std::forward(args)...); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 29adb0fed108d..284dca5cce6b0 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,11 +19,12 @@ /*! * \file rpc_event_impl.cc - * \brief Event based RPC server implementation. + * \brief Event driven RPC server implementation. */ #include #include -#include "rpc_session.h" +#include "rpc_endpoint.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { @@ -35,16 +36,17 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, LOG(FATAL) << "Do not allow explicit receive"; return 0; }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = + RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerEventHandler(args[0], args[1]); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); *rv = ret; }); } -TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer") .set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc new file mode 100644 index 0000000000000..0a2809bbaf4c9 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file local_session.cc + * \brief Local session that directs requests to local API. + */ +#include +#include +#include +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +RPCSession::PackedFuncHandle +LocalSession::GetFunction(const std::string& name) { + PackedFunc pf = this->GetFunctionInternal(name); + // return raw handl because the remote need to explicitly manage it. + if (pf != nullptr) return new PackedFunc(pf); + return nullptr; +} + +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& encode_return) { + auto* pf = static_cast(func); + TVMRetValue rv; + + pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + int rv_tcode = rv.type_code(); + + // return value encoding. + TVMValue ret_value_pack[3]; + int ret_tcode_pack[3]; + TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack); + // first location always encode type code. + set_arg(0, rv_tcode); + + if (rv_tcode == kTVMNDArrayHandle) { + // We follow a special protocol to return NDArray to client side + // The first pack value is the NDArray handle as DLTensor + // The second pack value is a customized deleter that deletes the NDArray. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMDLTensorHandle; + ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; + ret_tcode_pack[2] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); + } else if (rv_tcode == kTVMPackedFuncHandle || + rv_tcode == kTVMModuleHandle) { + // MoveToCHost means rv no longer manages the object. + // return handle instead. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else if (rv_tcode == kTVMBytes) { + TVMByteArray byte_arr; + auto* sptr = rv.ptr(); + byte_arr.data = sptr->data(); + byte_arr.size = sptr->length(); + set_arg(1, byte_arr); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else { + set_arg(1, rv); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } +} + +void LocalSession::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + this->GetDeviceAPI(ctx_to)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, cpu_ctx, ctx_to, type_hint, nullptr); +} + +void LocalSession::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + + this->GetDeviceAPI(ctx_from)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, ctx_from, cpu_ctx, type_hint, nullptr); +} + +void LocalSession::FreeHandle(void* handle, int type_code) { + TVMValue value; + value.v_handle = handle; + // will trigger deleter once the rv goes out of the scope. + TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code); +} + +DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { + return DeviceAPI::Get(ctx, allow_missing); +} + +PackedFunc LocalSession::GetFunctionInternal(const std::string& name) { + auto* fp = tvm::runtime::Registry::Get(name); + if (fp != nullptr) { + return *fp; + } else { + return nullptr; + } +} + +TVM_REGISTER_GLOBAL("rpc.LocalSession") +.set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h new file mode 100644 index 0000000000000..ebb3ea11c50e7 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_local_session.h + * \brief Local session that directs all request to the local runtime API. + */ +#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ +#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ + +#include +#include +#include +#include +#include "rpc_session.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief A local session that directly use the handle repr of the + * local tvm runtime objects on the same process. + */ +class LocalSession : public RPCSession { + public: + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final; + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final; + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final; + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final; + + void FreeHandle(void* handle, int type_code) final; + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final; + + protected: + /*! + * \brief Internal implementation of GetFunction. + * \param name The name of the function. + * \return The corresponding PackedFunc. + */ + virtual PackedFunc GetFunctionInternal(const std::string& name); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 0e48e6fb27089..3d8560e6b558e 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -18,63 +18,117 @@ */ /*! - * \file rpc_device_api.cc - * \brief RPC module. + * \file rpc_module.cc + * \brief RPC runtime module. */ #include +#include #include #include +#include "rpc_endpoint.h" #include "rpc_session.h" namespace tvm { namespace runtime { -// Wrapped remote function to packed func. -class RPCWrappedFunc { +/*! + * \brief A wrapped remote function as a PackedFunc. + */ +class RPCWrappedFunc : public Object { public: RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) { - fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - WrapRemote(sess, args, rv); - }); } - void operator()(TVMArgs args, TVMRetValue *rv) const { - sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_); + void operator()(TVMArgs args, TVMRetValue* rv) const { + std::vector values(args.values, args.values + args.size()); + std::vector type_codes(args.type_codes, args.type_codes + args.size()); + std::vector> temp_dltensors; + + // scan and check whether we need rewrite these arguments + // to their remote variant. + for (int i = 0; i < args.size(); ++i) { + int tcode = type_codes[i]; + + switch (tcode) { + case kTVMDLTensorHandle: + case kTVMNDArrayHandle: { + // Pass NDArray as DLTensor, NDArray and DLTensor + // are compatible to each other, just need to change the index. + type_codes[i] = kTVMDLTensorHandle; + // translate to a remote view of DLTensor + auto dptr = std::make_unique( + *static_cast(values[i].v_handle)); + dptr->ctx = RemoveSessMask(dptr->ctx); + dptr->data = static_cast(dptr->data)->data; + values[i].v_handle = dptr.get(); + temp_dltensors.emplace_back(std::move(dptr)); + break; + } + case kTVMContext: { + values[i].v_ctx = RemoveSessMask(values[i].v_ctx); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + values[i].v_handle = UnwrapRemoteValueToHandle( + TVMArgValue(values[i], tcode)); + break; + } + } + } + auto set_return = [this, rv](TVMArgs args) { + this->WrapRemoteReturnToValue(args, rv); + }; + sess_->CallFunc(handle_, values.data(), type_codes.data(), + args.size(), set_return); } + ~RPCWrappedFunc() { try { - sess_->CallRemote(RPCCode::kFreeFunc, handle_); + sess_->FreeHandle(handle_, kTVMPackedFuncHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } } - static void WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue* rv); + private: + // remote function handle + void* handle_{nullptr}; + // pointer to the session. + std::shared_ptr sess_; + + // unwrap a remote value to the underlying handle. + void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; + // wrap a remote return via Set + void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; - static void* UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg); + // remove a remote session mask + TVMContext RemoveSessMask(TVMContext ctx) const { + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "Can not pass in local context or context with a different remote session"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } // deleter of RPC remote array static void RemoteNDArrayDeleter(Object* obj) { auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); - space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); + space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle); delete space; delete ptr; } + // wrap return value as remote NDArray. - static NDArray WrapRemoteNDArray(std::shared_ptr sess, - DLTensor* tensor, - void* nd_handle) { + NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); - space->sess = sess; + space->sess = sess_; space->data = tensor->data; data->dl_tensor.data = space; NDArray ret(GetObjectPtr(data)); @@ -89,18 +143,13 @@ class RPCWrappedFunc { data->dl_tensor.ctx.device_id = tensor->ctx.device_id; data->dl_tensor.ctx.device_type = static_cast( static_cast(tensor->ctx.device_type) + - kRPCSessMask * (sess->table_index() + 1)); + kRPCSessMask * (sess_->table_index() + 1)); // check strides. CHECK(tensor->strides == nullptr); // setup byteoffset data->dl_tensor.byte_offset = tensor->byte_offset; return ret; } - - private: - PackedFunc fwrap_; - void* handle_{nullptr}; - std::shared_ptr sess_; }; // RPC that represents a remote module session. @@ -109,10 +158,11 @@ class RPCModuleNode final : public ModuleNode { RPCModuleNode(void* module_handle, std::shared_ptr sess) : module_handle_(module_handle), sess_(sess) { } + ~RPCModuleNode() { if (module_handle_ != nullptr) { try { - sess_->CallRemote(RPCCode::kModuleFree, module_handle_); + sess_->FreeHandle(module_handle_, kTVMModuleHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } @@ -127,31 +177,56 @@ class RPCModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, const ObjectPtr& sptr_to_self) final { - RPCFuncHandle handle = GetFuncHandle(name); - return WrapRemote(handle); + if (module_handle_ == nullptr) { + return WrapRemoteFunc(sess_->GetFunction(name)); + } else { + InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); + return remote_mod_get_function_(GetRef(this), name, false); + } } std::string GetSource(const std::string& format) final { - if (module_handle_ != nullptr) { - std::string ret = sess_->CallRemote( - RPCCode::kModuleGetSource, module_handle_, format); - } + LOG(FATAL) << "GetSource for rpc Module is not supported"; return ""; } - std::shared_ptr& sess() { - return sess_; - } - PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - RPCFuncHandle handle = GetFuncHandle(name); - if (handle == nullptr) return PackedFunc(); - handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms); - return WrapRemote(handle); + InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); + // Remove session mask because we pass ctx by parts. + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + + if (module_handle_ != nullptr) { + return remote_get_time_evaluator_( + GetRef(this), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } else { + return remote_get_time_evaluator_( + Optional(nullptr), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } + } + + Module LoadModule(std::string name) { + InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); + return remote_load_module_(name); + } + + void ImportModule(Module other) { + InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); + remote_import_module_(GetRef(this), other); + } + + const std::shared_ptr& sess() { + return sess_; } void* module_handle() const { @@ -159,7 +234,15 @@ class RPCModuleNode final : public ModuleNode { } private: - PackedFunc WrapRemote(RPCFuncHandle handle) { + template + void InitRemoteFunc(FType* func, const std::string& name) { + if (*func != nullptr) return; + RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); + CHECK(handle != nullptr) << "Cannot found remote function " << name; + *func = WrapRemoteFunc(handle); + } + + PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared(handle, sess_); return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { @@ -167,33 +250,30 @@ class RPCModuleNode final : public ModuleNode { }); } - RPCFuncHandle GetFuncHandle(const std::string& name) { - RPCFuncHandle handle = nullptr; - if (module_handle_ == nullptr) { - handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); - } else { - handle = sess_->CallRemote( - RPCCode::kModuleGetFunc, module_handle_, name); - } - return handle; - } // The module handle void* module_handle_{nullptr}; // The local channel std::shared_ptr sess_; - // Wrap function to wrap remote module/function. - PackedFunc fwrap_; + // remote function to get time evaluator + TypedPackedFunc, std::string, int, int, int, int, int)> + remote_get_time_evaluator_; + // remote function getter for modules. + TypedPackedFunc remote_mod_get_function_; + // remote function getter for load module + TypedPackedFunc remote_load_module_; + // remote function getter for load module + TypedPackedFunc remote_import_module_; }; -void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg) { + +void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { if (arg.type_code() == kTVMModuleHandle) { Module mod = arg; std::string tkey = mod->type_key(); CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); - CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index) + CHECK(rmod->sess() == sess_) << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { @@ -204,93 +284,172 @@ void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, } } -void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue *rv) { - void* handle = args.values[0].v_handle; - int tcode = args.type_codes[0]; +void RPCWrappedFunc::WrapRemoteReturnToValue( + TVMArgs args, + TVMRetValue *rv) const { + int tcode = args[0]; - if (handle == nullptr) return; + if (tcode == kTVMNullptr) return; if (tcode == kTVMPackedFuncHandle) { - auto wf = std::make_shared(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto wf = std::make_shared(handle, sess_); *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + return wf->operator()(args, rv); + }); } else if (tcode == kTVMModuleHandle) { - auto n = make_object(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); *rv = Module(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { - CHECK_EQ(args.size(), 2); - DLTensor* tensor = args[0]; - void* nd_handle = args[1]; - *rv = WrapRemoteNDArray(sess, tensor, nd_handle); + CHECK_EQ(args.size(), 3); + DLTensor* tensor = args[1]; + void* nd_handle = args[2]; + *rv = WrapRemoteNDArray(tensor, nd_handle); } else { - LOG(FATAL) << "Cannot wrap tcode=" << tcode; + CHECK_EQ(args.size(), 2); + *rv = args[1]; } } -Module CreateRPCModule(std::shared_ptr sess) { +Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); + RPCSession::InsertToSessionTable(sess); return Module(n); } +std::shared_ptr RPCModuleGetSession(Module mod) { + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "ValueError: Cannot pass a non-RPC module to remote"; + auto* rmod = static_cast(mod.operator->()); + return rmod->sess(); +} + +PackedFunc WrapTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + CHECK(pf != nullptr); + + if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { + auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); + CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; + return (*get_micro_time_evaluator)(pf, ctx, number, repeat); + } + + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) + mutable { + TVMRetValue temp; + std::ostringstream os; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + + for (int i = 0; i < repeat; ++i) { + std::chrono::time_point< + std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + + tbegin = std::chrono::high_resolution_clock::now(); + // start timing + for (int i = 0; i < number; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + tend = std::chrono::high_resolution_clock::now(); + + duration_ms = std::chrono::duration_cast > + (tend - tbegin).count() * 1000; + } while (duration_ms < min_repeat_ms); + + double speed = std::chrono::duration_cast >( + tend - tbegin).count() / number; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + + TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); std::string tkey = m->type_key(); - TVMContext ctx; - ctx.device_type = static_cast(args[2].operator int()); - ctx.device_id = args[3]; if (tkey == "rpc") { - *rv = static_cast(m.operator->()) - ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]); + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); } else { - *rv = WrapTimeEvaluator( - m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]); + return WrapTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); } - }); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); -TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - auto& sess = static_cast(m.operator->())->sess(); - void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); - auto n = make_object(mhandle, sess); - *rv = Module(n); - }); +// server function registration. +TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") +.set_body_typed([](Module parent, Module child) { + parent->Import(child); +}); -TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module parent = args[0]; - Module child = args[1]; - CHECK(!std::strcmp(parent->type_key(), "rpc") && - !std::strcmp(child->type_key(), "rpc")); - auto* pmod = static_cast(parent.operator->()); - auto* cmod = static_cast(child.operator->()); - CHECK(pmod->sess().get() == cmod->sess().get()) - << "Import of remote module need to belong to same session."; - pmod->sess()->CallRemote(RPCCode::kModuleImport, - pmod->module_handle(), - cmod->module_handle()); - }); - -TVM_REGISTER_GLOBAL("rpc._ModuleHandle") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->module_handle(); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") +.set_body_typed([](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); +}); + +// functions to access an RPC module. +TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule") +.set_body_typed([](Module sess, std::string name) { + std::string tkey = sess->type_key(); + CHECK_EQ(tkey, "rpc"); + return static_cast(sess.operator->())->LoadModule(name); +}); -TVM_REGISTER_GLOBAL("rpc._SessTableIndex") +TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule") +.set_body_typed([](Module parent, Module child) { + std::string tkey = parent->type_key(); + CHECK_EQ(tkey, "rpc"); + static_cast(parent.operator->())->ImportModule(child); +}); + +TVM_REGISTER_GLOBAL("rpc.SessTableIndex") .set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); - }); + Module m = args[0]; + std::string tkey = m->type_key(); + CHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index f6a7fb60b5f4d..612ca418e812b 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,8 @@ namespace tvm { namespace runtime { std::string RPCGetPath(const std::string& name) { - static const PackedFunc* f = + // do live lookup everytime as workpath can change. + const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath"); CHECK(f != nullptr) << "require tvm.rpc.server.workpath"; return (*f)(name); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index ae293abfacdd2..dd0afa0145d22 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -21,817 +21,16 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ -#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include -#include #include "rpc_session.h" -#include "../object_internal.h" -#include "../../support/ring_buffer.h" -#include "../../support/socket.h" -#include "../micro/micro_session.h" namespace tvm { namespace runtime { -// Temp buffer for data array -struct RPCByteArrayBuffer { - TVMByteArray arr; - std::string data; -}; -// Temp buffer for data array -struct RPCDataArrayBuffer { - DLTensor tensor; - std::vector shape; -}; -/*! - * \brief Temporal argument buffer. - */ -struct RPCArgBuffer { - // The argument values - std::vector value; - // The type codes. - std::vector tcode; - // Temporal resources. - std::vector > temp_bytes; - // Temporal array - std::vector > temp_array; - // convert buffer as TVMArgs - TVMArgs AsTVMArgs() const { - return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); - } -}; - -// Event handler for RPC events. -class RPCSession::EventHandler : public dmlc::Stream { - public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - int rpc_sess_table_index, - std::string name, - std::string* remote_key) - : reader_(reader), - writer_(writer), - rpc_sess_table_index_(rpc_sess_table_index), - name_(name), - remote_key_(remote_key) { - this->Clear(); - if (*remote_key == "%toinit") { - state_ = kInitHeader; - remote_key_->resize(0); - pending_request_bytes_ = sizeof(int32_t); - } - } - // Bytes needed to fulfill current request - size_t BytesNeeded() { - if (reader_->bytes_available() < pending_request_bytes_) { - return pending_request_bytes_ - reader_->bytes_available(); - } else { - return 0; - } - } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. - bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; - } - bool CanCleanShutdown() const { - return state_ == kRecvCode; - } - void FinishCopyAck() { - this->SwitchToState(kRecvCode); - } - RPCCode HandleNextEvent(TVMRetValue* rv, - bool client_mode, - const PackedFunc* fwrap) { - std::swap(client_mode_, client_mode); - while (this->Ready()) { - switch (state_) { - case kInitHeader: HandleInitHeader(); break; - case kRecvCode: HandleRecvCode(); break; - case kRecvCallHandle: { - CHECK(this->Read(&call_handle_)); - this->SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case kRecvPackedSeqNumArgs: { - CHECK(this->Read(&num_packed_args_)); - arg_buf_.reset(new RPCArgBuffer()); - arg_buf_->value.resize(num_packed_args_); - arg_buf_->tcode.resize(num_packed_args_); - this->SwitchToState(kRecvPackedSeqTypeCode); - break; - } - case kRecvPackedSeqTypeCode: { - if (num_packed_args_ != 0) { - this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); - } - arg_index_ = 0; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kRecvPackedSeqArg: { - this->HandleRecvPackedSeqArg(); - break; - } - case kDoCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case kDoCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case kReturnReceived: { - CHECK_GE(arg_buf_->value.size(), 1U); - - TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; - if (argv.type_code() == kTVMPackedFuncHandle || - argv.type_code() == kTVMModuleHandle || - argv.type_code() == kTVMDLTensorHandle) { - CHECK(fwrap != nullptr) << "function/module wrapper not available"; - fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); - } else { - CHECK_EQ(arg_buf_->value.size(), 1U); - *rv = argv; - } - arg_buf_.reset(); - this->SwitchToState(kRecvCode); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; - } - case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; - } - case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; - } - } - } - std::swap(client_mode_, client_mode); - return RPCCode::kNone; - } - // Reset and clear all states. - void Clear() { - state_ = kRecvCode; - pending_request_bytes_ = sizeof(RPCCode); - arg_recv_stage_ = 0; - arg_buf_.reset(); - } - // strip session on mask - TVMContext StripSessMask(TVMContext ctx) { - int dev_type = ctx.device_type; - CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1) - << "Can not pass in local context or context with a different remote session"; - ctx.device_type = static_cast(dev_type % kRPCSessMask); - return ctx; - } - // Send Packed sequence to writer. - // - // client_mode: whether we are in client mode. - // - // funwrap: auxiliary function to unwrap remote Object - // when it is provided, we need to unwrap objects. - // - // return_ndarray is a special flag to handle returning of ndarray - // In this case, we return the shape, context and data of the array, - // as well as a customized PackedFunc that handles deletion of - // the array in the remote. - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - FUnwrapRemoteObject funwrap = nullptr, - bool return_ndarray = false) { - std::swap(client_mode_, client_mode); - - this->Write(num_args); - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle; - this->Write(tcode); - } - - // Argument packing. - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Write(value.v_int64); - break; - } - case kTVMDataType: { - this->Write(value.v_type); - // padding - int32_t padding = 0; - this->Write(padding); - break; - } - case kTVMContext: { - value.v_ctx = StripSessMask(value.v_ctx); - this->Write(value.v_ctx); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { - // always send handle in 64 bit. - uint64_t handle; - // allow pass module as argument to remote. - if (funwrap != nullptr) { - void* remote_handle = (*funwrap)( - rpc_sess_table_index_, - runtime::TVMArgValue(value, tcode)); - handle = reinterpret_cast(remote_handle); - } else { - CHECK(!client_mode_) - << "Cannot directly pass remote object as argument"; - handle = reinterpret_cast(value.v_handle); - } - this->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - this->Write(handle); - break; - } - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); - TVMContext ctx; - uint64_t data; - if (!return_ndarray) { - // in the client mode - // ctx contains the remote table index - // the space is wrapped by an RemoteSpace - // that holds reference to the session. - ctx = StripSessMask(arr->ctx); - data = reinterpret_cast( - static_cast(arr->data)->data); - } else { - // When we return NDArray, we directly return - // the space and the context - // The client will be further wrapping - ctx = arr->ctx; - data = reinterpret_cast(arr->data); - } - this->Write(data); - this->Write(ctx); - this->Write(arr->ndim); - this->Write(arr->dtype); - this->WriteArray(arr->shape, arr->ndim); - CHECK(arr->strides == nullptr) - << "Do not support strided remote array"; - CHECK_EQ(arr->byte_offset, 0) - << "Do not support send byte offset"; - break; - } - case kTVMNullptr: break; - case kTVMStr: { - const char* s = value.v_str; - uint64_t len = strlen(s); - this->Write(len); - this->WriteArray(s, len); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); - uint64_t len = bytes->size; - this->Write(len); - this->WriteArray(bytes->data, len); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - std::swap(client_mode_, client_mode); - } - - // Endian aware IO handling - using Stream::Read; - using Stream::Write; - using Stream::ReadArray; - using Stream::WriteArray; - - inline bool Read(RPCCode* code) { - int cdata; - if (!this->Read(&cdata)) return false; - *code = static_cast(cdata); - return true; - } - inline void Write(RPCCode code) { - int cdata = static_cast(code); - this->Write(cdata); - } - - protected: - enum State { - kInitHeader, - kRecvCode, - kRecvCallHandle, - kRecvPackedSeqNumArgs, - kRecvPackedSeqTypeCode, - kRecvPackedSeqArg, - kDoCopyFromRemote, - kDoCopyToRemote, - kReturnReceived, - kCopyAckReceived, - kShutdownReceived - }; - // Current state; - State state_; - // The RPCCode to be read. - RPCCode code_; - // Handle for the remote function call. - uint64_t call_handle_; - // Initialize remote header - bool init_header_step_{0}; - // Number of packed arguments. - int num_packed_args_; - // Current argument index. - int arg_index_; - // The stage of each argument receiver. - int arg_recv_stage_; - // Whether current handler is client or server mode. - bool client_mode_{false}; - // Argument buffer - std::unique_ptr arg_buf_; - // Temp byte buffer. - std::unique_ptr temp_bytes_; - // Temp array buffer. - std::unique_ptr temp_array_; - // Internal temporal data space. - std::string temp_data_; - // Temp variables for copy request state. - TVMContext copy_ctx_; - DLDataType copy_dtype_; - uint64_t copy_handle_, copy_offset_, copy_size_; - // State switcher - void SwitchToState(State state) { - // invariant - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; - state_ = state; - switch (state) { - case kInitHeader: { - LOG(FATAL) << "cannot switch to init header"; - break; - } - case kRecvCode: { - this->RequestBytes(sizeof(RPCCode)); - break; - } - case kRecvCallHandle: { - this->RequestBytes(sizeof(call_handle_)); - break; - } - case kRecvPackedSeqNumArgs: { - this->RequestBytes(sizeof(num_packed_args_)); - break; - } - case kRecvPackedSeqTypeCode: { - this->RequestBytes(sizeof(int) * num_packed_args_); - break; - } - case kRecvPackedSeqArg: { - CHECK_LE(arg_index_, num_packed_args_); - if (arg_index_ == num_packed_args_) { - // The function can change state_ again. - HandlePackedCall(); - } else { - RequestRecvPackedSeqArg(); - } - break; - } - case kDoCopyFromRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kDoCopyToRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kCopyAckReceived: - case kReturnReceived: - case kShutdownReceived: { - break; - } - } - } - // Requets bytes needed for next computation. - void RequestRecvPackedSeqArg() { - CHECK_EQ(arg_recv_stage_, 0); - int tcode = arg_buf_->tcode[arg_index_]; - static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: - case kTVMDataType: - case kTVMOpaqueHandle: - case kTVMStr: - case kTVMBytes: - case kTVMModuleHandle: - case kTVMContext: { - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMPackedFuncHandle: { - CHECK(client_mode_) - << "Only client can receive remote functions"; - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMNullptr: break; - case kTVMDLTensorHandle: { - this->RequestBytes(sizeof(uint64_t)); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(int)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - // Handler for packed sequence argument receive. - void HandleRecvPackedSeqArg() { - CHECK_LT(arg_index_, num_packed_args_); - int tcode = arg_buf_->tcode[arg_index_]; - TVMValue& value = arg_buf_->value[arg_index_]; - if (arg_recv_stage_ == 0) { - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Read(&(value.v_int64)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMDataType: { - this->Read(&(value.v_type)); - int32_t padding = 0; - this->Read(&padding); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMContext: { - this->Read(&(value.v_ctx)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle; - this->Read(&handle); - value.v_handle = reinterpret_cast(handle); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMStr: - case kTVMBytes: { - uint64_t len; - this->Read(&len); - temp_bytes_.reset( new RPCByteArrayBuffer()); - temp_bytes_->data.resize(len); - arg_recv_stage_ = 1; - this->RequestBytes(len); - break; - } - case kTVMDLTensorHandle: { - temp_array_.reset(new RPCDataArrayBuffer()); - uint64_t handle; - this->Read(&handle); - DLTensor& tensor = temp_array_->tensor; - tensor.data = reinterpret_cast(handle); - this->Read(&(tensor.ctx)); - this->Read(&(tensor.ndim)); - this->Read(&(tensor.dtype)); - temp_array_->shape.resize(tensor.ndim); - tensor.shape = temp_array_->shape.data(); - arg_recv_stage_ = 1; - tensor.strides = nullptr; - tensor.byte_offset = 0; - this->RequestBytes(sizeof(int64_t) * tensor.ndim); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } else { - CHECK_EQ(arg_recv_stage_, 1); - if (tcode == kTVMStr || tcode == kTVMBytes) { - if (temp_bytes_->data.size() != 0) { - this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); - } - if (tcode == kTVMStr) { - value.v_str = temp_bytes_->data.c_str(); - } else { - temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); - temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); - value.v_handle = &(temp_bytes_->arr); - } - arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); - } else { - CHECK_EQ(tcode, kTVMDLTensorHandle); - DLTensor& tensor = temp_array_->tensor; - this->ReadArray(tensor.shape, tensor.ndim); - value.v_handle = &tensor; - arg_buf_->temp_array.emplace_back(std::move(temp_array_)); - } - ++arg_index_; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - } - } - // handler for initial header read - void HandleInitHeader() { - if (init_header_step_ == 0) { - int32_t len; - this->Read(&len); - remote_key_->resize(len); - init_header_step_ = 1; - this->RequestBytes(len); - return; - } else { - CHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); - this->SwitchToState(kRecvCode); - } - } - // Handler for read code. - void HandleRecvCode() { - this->Read(&code_); - if (code_ > RPCCode::kSystemFuncStart) { - SwitchToState(kRecvPackedSeqNumArgs); - return; - } - // invariant. - CHECK_EQ(arg_recv_stage_, 0); - switch (code_) { - case RPCCode::kCallFunc: { - SwitchToState(kRecvCallHandle); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case RPCCode::kCopyFromRemote: { - SwitchToState(kDoCopyFromRemote); - break; - } - case RPCCode::kCopyToRemote: { - SwitchToState(kDoCopyToRemote); - break; - } - case RPCCode::kShutdown: { - SwitchToState(kShutdownReceived); - break; - } - case RPCCode::kCopyAck: { - SwitchToState(kCopyAckReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - } - - void HandleCopyFromRemote() { - uint64_t handle, offset, num_bytes; - TVMContext ctx; - DLDataType type_hint; - this->Read(&handle); - this->Read(&offset); - this->Read(&num_bytes); - this->Read(&ctx); - this->Read(&type_hint); - size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; - - if (ctx.device_type == kDLCPU) { - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - char* dptr = reinterpret_cast(handle) + offset; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - temp_data_.resize(0); - temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - this->WriteArray(temp_data_.data(), num_bytes); - } else { - this->WriteArray(dptr, num_bytes); - } - } else { - temp_data_.resize(num_bytes + 1); - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(ctx)->CopyDataFromTo( - reinterpret_cast(handle), offset, - dmlc::BeginPtr(temp_data_), 0, - num_bytes, ctx, cpu_ctx, type_hint, nullptr); - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - } - this->WriteArray(&temp_data_[0], num_bytes); - } catch (const std::runtime_error &e) { - RPCCode code = RPCCode::kException; - this->Write(code); - TVMValue ret_value; - ret_value.v_str = e.what(); - int ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - this->SwitchToState(kRecvCode); - } - - void HandleCopyToRemote() { - // use static variable to persist state. - // This only works if next stage is immediately after this. - if (arg_recv_stage_ == 0) { - CHECK(this->Read(©_handle_)); - CHECK(this->Read(©_offset_)); - CHECK(this->Read(©_size_)); - CHECK(this->Read(©_ctx_)); - CHECK(this->Read(©_dtype_)); - arg_recv_stage_ = 1; - CHECK_EQ(pending_request_bytes_, 0U); - this->RequestBytes(copy_size_); - } else { - CHECK_EQ(arg_recv_stage_, 1); - TVMValue ret_value; - ret_value.v_handle = nullptr; - int ret_tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - std::string errmsg; - - size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; - if (copy_ctx_.device_type == kDLCPU) { - char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; - this->ReadArray(dptr, copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); - } - } else { - temp_data_.resize(copy_size_ + 1); - this->ReadArray(&temp_data_[0], copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); - } - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( - temp_data_.data(), 0, - reinterpret_cast(copy_handle_), copy_offset_, - copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr); - } catch (const std::runtime_error &e) { - code = RPCCode::kException; - errmsg = e.what(); - ret_value.v_str = errmsg.c_str(); - ret_tcode = kTVMStr; - } - } - this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - arg_recv_stage_ = 0; - this->SwitchToState(kRecvCode); - } - } - // Handle for packed call. - void HandlePackedCall(); - - template - void CallHandler(F f) { - TVMRetValue rv; - TVMValue ret_value; - int ret_tcode; - try { - // Need to move out, in case f itself need to call RecvPackedSeq - // Which will override argbuf again. - std::unique_ptr args = std::move(arg_buf_); - f(args->AsTVMArgs(), &rv); - RPCCode code = RPCCode::kReturn; - this->Write(code); - if (rv.type_code() == kTVMStr) { - ret_value.v_str = rv.ptr()->c_str(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMBytes) { - std::string* bytes = rv.ptr(); - TVMByteArray arr; - arr.data = bytes->c_str(); - arr.size = bytes->length(); - ret_value.v_handle = &arr; - ret_tcode = kTVMBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMPackedFuncHandle || - rv.type_code() == kTVMModuleHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send function and module handle back."; - rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMNDArrayHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send NDArray back"; - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. - TVMValue ret_value_pack[2]; - int ret_tcode_pack[2]; - rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; - ret_tcode_pack[1] = kTVMOpaqueHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); - } else { - ret_value = rv.value(); - ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } catch (const std::runtime_error& e) { - RPCCode code = RPCCode::kException; - this->Write(code); - ret_value.v_str = e.what(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - - private: - // Utility functions - // Internal read function, update pending_request_bytes_ - size_t Read(void* data, size_t size) final { - CHECK_LE(size, pending_request_bytes_); - reader_->Read(data, size); - pending_request_bytes_ -= size; - return size; - } - void Write(const void* data, size_t size) final { - writer_->Write(data, size); - } - // Number of pending bytes requests - size_t pending_request_bytes_; - // The ring buffer to read data from. - support::RingBuffer* reader_; - // The ringr buffer to write reply to. - support::RingBuffer* writer_; - // Session table index. - int rpc_sess_table_index_; - // Name of session. - std::string name_; - // remote key - std::string* remote_key_; -}; - -struct RPCSessTable { +class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; // Get global singleton @@ -864,465 +63,13 @@ struct RPCSessTable { std::array, kMaxRPCSession> tbl_; }; -RPCCode RPCSession::HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) { - RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { - while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - size_t bytes_needed = handler_->BytesNeeded(); - if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); - if (n == 0) { - if (handler_->CanCleanShutdown()) { - return RPCCode::kShutdown; - } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; - } - } - } - code = handler_->HandleNextEvent(rv, client_mode, fwrap); - } - return code; -} - -void RPCSession::Init() { - // Event handler - handler_ = std::make_shared( - &reader_, &writer_, table_index_, name_, &remote_key_); - // Quick function to call remote. - call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); - RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); - }); -} - -std::shared_ptr RPCSession::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { - std::shared_ptr sess = std::make_shared(); - sess->channel_ = std::move(channel); - sess->name_ = std::move(name); - sess->remote_key_ = std::move(remote_key); - sess->table_index_ = RPCSessTable::Global()->Insert(sess); - sess->Init(); - return sess; -} - std::shared_ptr RPCSession::Get(int table_index) { return RPCSessTable::Global()->Get(table_index); } -RPCSession::~RPCSession() { - this->Shutdown(); -} - -void RPCSession::Shutdown() { - if (channel_ != nullptr) { - RPCCode code = RPCCode::kShutdown; - handler_->Write(code); - // flush all writing buffer to output channel. - try { - while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - if (n == 0) break; - } - } catch (const dmlc::Error& e) { - } - channel_.reset(nullptr); - } -} - -void RPCSession::ServerLoop() { - std::lock_guard lock(mutex_); - if (const auto* f = Registry::Get("tvm.rpc.server.start")) { - (*f)(); - } - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { - (*f)(); - } - channel_.reset(nullptr); -} - -int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { - std::lock_guard lock(mutex_); - RPCCode code = RPCCode::kNone; - if (bytes.length() != 0) { - reader_.Write(bytes.c_str(), bytes.length()); - TVMRetValue rv; - code = handler_->HandleNextEvent(&rv, false, nullptr); - } - if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); - if (code == RPCCode::kShutdown) return 0; - if (writer_.bytes_available() != 0) return 2; - return 1; -} - -// Get remote function with name -void RPCSession::CallFunc(void* h, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap) { - std::lock_guard lock(mutex_); - - RPCCode code = RPCCode::kCallFunc; - handler_->Write(code); - uint64_t handle = reinterpret_cast(h); - handler_->Write(handle); - handler_->SendPackedSeq( - args.values, args.type_codes, args.num_args, true, funwrap); - code = HandleUntilReturnEvent(rv, true, fwrap); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); -} - -void RPCSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_to = handler_->StripSessMask(ctx_to); - RPCCode code = RPCCode::kCopyToRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(to); - handler_->Write(handle); - uint64_t offset = static_cast(to_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_to); - handler_->Write(type_hint); - handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn); -} - -void RPCSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_from = handler_->StripSessMask(ctx_from); - RPCCode code = RPCCode::kCopyFromRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(from); - handler_->Write(handle); - uint64_t offset = static_cast(from_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_from); - handler_->Write(type_hint); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); - reader_.Reserve(data_size); - handler_->RequestBytes(data_size); - while (!handler_->Ready()) { - size_t bytes_needed = handler_->BytesNeeded(); - reader_.WriteWithCallback([this](void* data, size_t size) { - size_t n = channel_->Recv(data, size); - CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; - return n; - }, bytes_needed); - } - handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); - handler_->FinishCopyAck(); -} - -RPCFuncHandle RPCSession::GetTimeEvaluator( - RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - return this->CallRemote( - RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms); -} - -// Event handler functions -void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - auto *fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - *rv = static_cast(new tvm::runtime::PackedFunc(*fp)); - } else { - *rv = nullptr; - } -} - -void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - delete static_cast(handle); -} - -void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAPI::Get(ctx)->SetDevice(ctx); -} - -void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAttrKind kind = static_cast(args[1].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPI::Get(ctx, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, rv); - } else { - *rv = 0; - } - } else { - DeviceAPI::Get(ctx)->GetAttr( - ctx, static_cast(kind), rv); - } -} - -void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - uint64_t nbytes = args[1]; - uint64_t alignment = args[2]; - DLDataType type_hint = args[3]; - void* data = DeviceAPI::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); - *rv = data; -} - -void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - void* ptr = args[1]; - DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr); -} - -void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - DeviceAPI::Get(ctx)->StreamSync(ctx, handle); -} - -void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { - void* from = args[0]; - uint64_t from_offset = args[1]; - void* to = args[2]; - uint64_t to_offset = args[3]; - uint64_t size = args[4]; - TVMContext ctx_from = args[5]; - TVMContext ctx_to = args[6]; - DLDataType type_hint = args[7]; - TVMStreamHandle stream = args[8]; - TVMContext ctx = ctx_from; - if (ctx.device_type == kDLCPU) { - ctx = ctx_to; - } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) - << "Can not copy across different ctx types directly"; - } - DeviceAPI::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, to_offset, - size, ctx_from, ctx_to, type_hint, stream); -} - -void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { - static const PackedFunc* fsys_load_ = nullptr; - if (fsys_load_ == nullptr) { - fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module"); - CHECK(fsys_load_ != nullptr); - } - std::string file_name = args[0]; - TVMRetValue ret = (*fsys_load_)(file_name); - // pass via void* - TVMValue value; - int rcode; - ret.MoveToCHost(&value, &rcode); - CHECK_EQ(rcode, kTVMModuleHandle); - *rv = static_cast(value.v_handle); -} - -void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { - void* pmod = args[0]; - void* cmod = args[1]; - ObjectInternal::GetModuleNode(pmod)->Import( - GetRef(ObjectInternal::GetModuleNode(cmod))); -} - -void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - ObjectInternal::ObjectFree(mhandle); -} - -void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( - args[1], false); - if (pf != nullptr) { - *rv = static_cast(new PackedFunc(pf)); - } else { - *rv = nullptr; - } -} - -void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - std::string fmt = args[1]; - *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); -} - -void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - static_cast( - reinterpret_cast(handle))->DecRef(); -} - -void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { - PackedFunc *pf = static_cast(args[0].operator void*()); - void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4])); - delete pf; - *rv = fhandle; -} - -void RPCSession::EventHandler::HandlePackedCall() { - CHECK_EQ(pending_request_bytes_, 0U); - if (code_ == RPCCode::kReturn) { - state_ = kReturnReceived; return; - } - // reset state to clean init state - state_ = kRecvCode; - this->RequestBytes(sizeof(RPCCode)); - // Event handler sit at clean state at this point. - switch (code_) { - case RPCCode::kCallFunc: { - PackedFunc* pf = reinterpret_cast(call_handle_); - CallHandler([pf](TVMArgs args, TVMRetValue* rv) { - pf->CallPacked(args, rv); - }); - break; - } - case RPCCode::kException: { - CHECK_EQ(arg_buf_->value.size(), 1U); - CHECK_EQ(arg_buf_->tcode[0], kTVMStr); - std::ostringstream os; - os << "Except caught from RPC call: " << arg_buf_->value[0].v_str; - arg_buf_.reset(); - throw dmlc::Error(os.str()); - break; - } - // system functions - case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; - case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; - case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; - case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; - case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; - case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; - case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; - case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; - case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; - case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - CHECK_EQ(state_, kRecvCode); -} - -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { - auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); - CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; - return (*get_micro_time_evaluator)(pf, ctx, number, repeat); - } - - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - tbegin = std::chrono::high_resolution_clock::now(); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - tend = std::chrono::high_resolution_clock::now(); - - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; - } while (duration_ms < min_repeat_ms); - - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -size_t CallbackChannel::Send(const void* data, size_t size) { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - int64_t n = fsend_(bytes); - if (n == -1) { - support::Socket::Error("CallbackChannel::Send"); - } - return static_cast(n); -} - -size_t CallbackChannel::Recv(void* data, size_t size) { - TVMRetValue ret = frecv_(size); - - if (ret.type_code() != kTVMBytes) { - support::Socket::Error("CallbackChannel::Recv"); - } - std::string* bytes = ret.ptr(); - memcpy(static_cast(data), bytes->c_str(), bytes->length()); - return bytes->length(); +void RPCSession::InsertToSessionTable(std::shared_ptr sess) { + CHECK_EQ(sess->table_index_, 0); + sess->table_index_ = RPCSessTable::Global()->Insert(sess); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index db63be4be74da..a715e7b0b20c7 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,230 +24,166 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ + #include #include -#include -#include +#include #include -#include -#include "../../support/ring_buffer.h" +#include namespace tvm { namespace runtime { -// Magic header for RPC data plane -const int kRPCMagic = 0xff271; -// magic header for RPC tracker(control plane) -const int kRPCTrackerMagic = 0x2f271; -// sucess response -const int kRPCSuccess = kRPCMagic + 0; -// cannot found matched key in server -const int kRPCMismatch = kRPCMagic + 2; - -/*! \brief Enumeration code for the RPC tracker */ -enum class TrackerCode : int { - kFail = -1, - kSuccess = 0, - kPing = 1, - kStop = 2, - kPut = 3, - kRequest = 4, - kUpdateInfo = 5, - kSummary = 6, - kGetPendingMatchKeys = 7 -}; -/*! \brief The remote functio handle */ -using RPCFuncHandle = void*; - -struct RPCArgBuffer; - -/*! \brief The RPC code */ -enum class RPCCode : int { - kNone, - kCallFunc, - kReturn, - kException, - kShutdown, - kCopyFromRemote, - kCopyToRemote, - kCopyAck, - // The following are code that can send over CallRemote - kSystemFuncStart, - kGetGlobalFunc, - kGetTimeEvaluator, - kFreeFunc, - kDevSetDevice, - kDevGetAttr, - kDevAllocData, - kDevFreeData, - kDevStreamSync, - kCopyAmongRemote, - kModuleLoad, - kModuleImport, - kModuleFree, - kModuleGetFunc, - kModuleGetSource, - kNDArrayFree -}; - /*! - * \brief Function that unwraps a remote object to its handle. - * \param rpc_sess_table_index RPC session table index for validation. - * \param obj Handle to the object argument. - * \return The corresponding handle. - */ -typedef void* (*FUnwrapRemoteObject)( - int rpc_sess_table_index, - const TVMArgValue& obj); - -/*! - * \brief Abstract channel interface used to create RPCSession. + * \brief The interface of all remote RPC sessions. + * + * It contains all the necessary interface to implement + * remote call and resource management. + * + * The interface is designed to allow easy proxy-chaining + * by forward requests to another RPCSession. */ -class RPCChannel { +class RPCSession { public: - /*! \brief virtual destructor */ - virtual ~RPCChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - virtual size_t Send(const void* data, size_t size) = 0; + /*! \brief PackedFunc Handle in the remote. */ + using PackedFuncHandle = void*; + + /*! \brief Module handle in the remote. */ + using ModuleHandle = void*; + + /*! \brief NDArray handle in the remote. */ + using NDArrayHandle = void*; + /*! - * \brief Recv data from channel. + * \brief Callback to send an encoded return values via encode_args. + * + * \param encode_args The arguments that we can encode the return values into. + * \param ret_tcode The actual remote type code of the return value. * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. + * Encoding convention (as list of arguments): + * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. + * - PackedFunc/Module: [tcode: int, handle: void*] + * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] + * DLTensor* contains the meta-data as well as handle into the remote data. + * nd_handle can be used for deletion. */ - virtual size_t Recv(void* data, size_t size) = 0; -}; + using FEncodeReturn = std::function; + + /*! \brief Destructor.*/ + virtual ~RPCSession() {} -// Bidirectional Communication Session of PackedRPC -class RPCSession { - public: - /*! \brief virtual destructor */ - ~RPCSession(); /*! - * \brief The server loop that server runs to handle RPC calls. + * \brief Get function in the session. + * \param name The name of the function. + * \return The function handle. */ - void ServerLoop(); + virtual PackedFuncHandle GetFunction(const std::string& name) = 0; + /*! - * \brief Message handling function for event driven server. - * Called when the server receives a message. - * Event driven handler will never call recv on the channel - * and always relies on the ServerEventHandler. - * to receive the data. + * \brief Call into a remote Packed function. * - * \param in_bytes The incoming bytes. - * \param event_flag 1: read_available, 2: write_avaiable. - * \return State flag. - * 1: continue running, no need to write, - * 2: need to write - * 0: shutdown - */ - int ServerEventHandler(const std::string& in_bytes, - int event_flag); - /*! - * \brief Call into remote function - * \param handle The function handle - * \param args The arguments - * \param rv The return value. - * \param funpwrap Function that takes a remote object and returns the raw handle. - * \param fwrap Wrapper function to turn Function/Module handle into real return. + * Calling convention: + * + * - type_code is follows the PackedFunc convention. + * - int/float/string/bytes follows the PackedFunc convention, all data are local. + * - PackedFunc/Module and future remote objects: pass remote handle instead. + * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor + * points to a remote data handle returned by the Device API. + * The meta-data of the DLTensor sits on local. + * + * The caller populates the arguments and manages these arguments. + * + * The callee can change the content of arg_values and arg_type_codes + * if they want to do inplace modify and forward. + * + * The callee need to store the return value into ret_value. + * - PackedFunc/Module are stored as void* + * - NDArray is stored as local NDArray, whose data field is a remote handle. + * Notably the NDArray's deleter won't delete remote handle. + * It is up to the user of the RPCSession to such wrapping. + * - In short, remote handles are "moved" as return values + * and the callee needs to explicitly manage them by calling + * the deleter functions when they are no longer needed. + * + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to set the return value, + * if not called, return value is null. */ - void CallFunc(RPCFuncHandle handle, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap); + virtual void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) = 0; + /*! * \brief Copy bytes into remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. - * \param to The target array. - * \param to_offset The byte offset in the to. + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_to The target context. + * \param remote_ctx_to The target context. * \param type_hint Hint of content data type. */ - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint); + virtual void CopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint) = 0; /*! * \brief Copy bytes from remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_from The source context. + * \param remote_ctx_from The source context in the remote. * \param type_hint Hint of content data type. */ - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint); + virtual void CopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint) = 0; + /*! - * \brief Get a remote timer function on ctx. - * This function consumes fhandle, caller should not call Free on fhandle. - * - * \param fhandle The function handle. - * \param ctx The ctx to run measurement on. - * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. - * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. - * \return A remote timer function + * \brief Free a remote function. + * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param type_code The type code of the underlying type. */ - RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms); + virtual void FreeHandle(void* handle, int type_code) = 0; + /*! - * \brief Call a remote defined system function with arguments. - * \param fcode The function code. - * \param args The arguments - * \return The returned remote value. + * \brief Get device API that represents the remote + * actions that can be taken on the remote. + * + * The caller can then call into the Alloc/Free functions + * to allocate free spaces and taking the pointer as the handle. + * + * The device API is guaranteed to be alive during the + * lifetime of the Session. + * + * \param ctx The remote context. + * \param allow_missing Whether can we return nullptr if it is not available. + * + * \return The device API. */ - template - inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args); + virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0; + /*! * \return The session table index of the session. */ int table_index() const { return table_index_; } - /*! - * \brief Create a RPC session with given channel. - * \param channel The communication channel. - * \param name The local name of the session, used for debug - * \param remote_key The remote key of the session - * if remote_key equals "%toinit", we need to re-intialize - * it by event handler. - */ - static std::shared_ptr Create( - std::unique_ptr channel, - std::string name, - std::string remote_key); + /*! * \brief Try get session from the global session table by table index. * \param table_index The table index of the session. @@ -256,62 +192,25 @@ class RPCSession { static std::shared_ptr Get(int table_index); private: - class EventHandler; - // Handle events until receives a return - // Also flushes channels so that the function advances. - RPCCode HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap); - // Initalization - void Init(); - // Shutdown - void Shutdown(); - // Internal channel. - std::unique_ptr channel_; - // Internal mutex - std::recursive_mutex mutex_; - // Internal ring buffer. - support::RingBuffer reader_, writer_; - // Event handler. - std::shared_ptr handler_; - // call remote with specified function code. - PackedFunc call_remote_; - // The index of this session in RPC session table. + /*! \brief index of this session in RPC session table */ int table_index_{0}; - // The name of the session. - std::string name_; - // The remote key - std::string remote_key_; + /*! \brief Insert the current session to the session table.*/ + static void InsertToSessionTable(std::shared_ptr sess); + // friend declaration + friend Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! - * \brief RPC channel which callback - * frontend (Python/Java/etc.)'s send & recv function + * \brief Remote space handle cell used by the RPC runtime API. + * + * When we allocate space using a rpc context, the data pointer + * points to an allocated RemoteSpace. */ -class CallbackChannel final : public RPCChannel { - public: - explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) - : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} - - ~CallbackChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - size_t Send(const void* data, size_t size) final; - /*! - * \brief Recv data from channel. - * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. - */ - size_t Recv(void* data, size_t size) final; - - private: - PackedFunc fsend_; - PackedFunc frecv_; +struct RemoteSpace { + /*! \brief The remote data handle. */ + void* data; + /*! \brief Reference to the underlying RPC session. */ + std::shared_ptr sess; }; /*! @@ -319,18 +218,18 @@ class CallbackChannel final : public RPCChannel { * \param f The function argument. * \param ctx The context. * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. + * We call these runs as one `repeat` of measurement. * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warm up and will be discarded. + * The returned result contains `repeat` costs, + * each of which is an average of `number` costs. * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. + * By default, one `repeat` contains `number` runs. If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * i.e., When the run time of one `repeat` falls below this time, + * the `number` parameter will be automatically increased. * \return f_timer A timer function. */ PackedFunc WrapTimeEvaluator(PackedFunc f, @@ -344,21 +243,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f, * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCModule(std::shared_ptr sess); +Module CreateRPCSessionModule(std::shared_ptr sess); -// Remote space pointer. -struct RemoteSpace { - void* data; - std::shared_ptr sess; -}; +/*! + * \brief Get the session module from a RPC session Module. + * \param mod The input module(must be an RPCModule). + * \return The internal RPCSession. + */ +std::shared_ptr RPCModuleGetSession(Module mod); -// implementation of inline functions -template -inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { - std::lock_guard lock(mutex_); - writer_.Write(&code, sizeof(code)); - return call_remote_(std::forward(args)...); -} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_ diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 642fbb8ec7f26..f3a30dd6c4858 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -22,8 +22,11 @@ * \brief Socket based RPC implementation. */ #include +#include #include +#include "rpc_endpoint.h" #include "rpc_session.h" +#include "rpc_local_session.h" #include "../../support/socket.h" namespace tvm { @@ -61,8 +64,8 @@ class SockChannel final : public RPCChannel { support::TCPSocket sock_; }; -std::shared_ptr -RPCConnect(std::string url, int port, std::string key) { +std::shared_ptr +RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); @@ -96,42 +99,56 @@ RPCConnect(std::string url, int port, std::string key) { remote_key.resize(keylen); CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - return RPCSession::Create( + auto endpt = RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), key, remote_key); + endpt->InitRemoteSession(init_seq); + return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key) { - return CreateRPCModule(RPCConnect(url, port, "client:" + key)); +Module RPCClientConnect(std::string url, + int port, + std::string key, + TVMArgs init_seq) { + auto endpt = RPCConnect(url, port, "client:" + key, init_seq); + return CreateRPCSessionModule(CreateClientSession(endpt)); } // TVM_DLL needed for MSVC TVM_DLL void RPCServerLoop(int sockfd) { support::TCPSocket sock( static_cast(sockfd)); - RPCSession::Create( + RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), "SockServerLoop", "")->ServerLoop(); } -void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { - RPCSession::Create(std::unique_ptr( - new CallbackChannel(fsend, frecv)), +void RPCServerLoop(PackedFunc fsend, + PackedFunc frecv) { + RPCEndpoint::Create( + std::unique_ptr(new CallbackChannel(fsend, frecv)), "SockServerLoop", "")->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc._Connect") -.set_body_typed(RPCClientConnect); +TVM_REGISTER_GLOBAL("rpc.Connect") +.set_body([](TVMArgs args, TVMRetValue *rv) { + std::string url = args[0]; + int port = args[1]; + std::string key = args[2]; + *rv = RPCClientConnect( + url, port, key, + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); +}); -TVM_REGISTER_GLOBAL("rpc._ServerLoop") +TVM_REGISTER_GLOBAL("rpc.ServerLoop") .set_body([](TVMArgs args, TVMRetValue* rv) { - if (args.size() == 1) { - RPCServerLoop(args[0]); - } else { - CHECK_EQ(args.size(), 2); - RPCServerLoop( - args[0].operator tvm::runtime::PackedFunc(), - args[1].operator tvm::runtime::PackedFunc()); - } - }); + if (args[0].type_code() == kDLInt) { + RPCServerLoop(args[0]); + } else { + RPCServerLoop( + args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); + } +}); + } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index b61e6bb9fa016..bc109f95b91d7 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -22,6 +22,7 @@ import time import multiprocessing +import pytest import numpy as np from tvm import rpc from tvm.contrib import util @@ -77,11 +78,9 @@ def remotethrow(name): f1 = client.get_function("rpc.test.addone") assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - try: + + with pytest.raises(tvm.error.RPCError): f3("abc") - assert False - except tvm.error.TVMError as e: - assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" @@ -101,6 +100,21 @@ def remote_array_func(y): fremote = remote.get_function("rpc.test.remote_array_func") fremote(r_cpu) + +def test_rpc_echo(): + def check(remote): + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + temp = rpc.server._server_env([]) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + check(rpc.LocalSession()) + check(client) + + def test_rpc_file_exchange(): if not tvm.runtime.enabled("rpc"): return @@ -114,14 +128,15 @@ def test_rpc_file_exchange(): def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) # graph n = tvm.runtime.convert(1024) A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + def check_remote(remote): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") @@ -188,8 +203,9 @@ def check_remote_link_cl(remote): fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote(client) check_remote(rpc.LocalSession()) + check_remote(client) + def test_rpc_return_func(): @@ -204,6 +220,37 @@ def addone(x): assert fadd(12) == 22 +def test_rpc_session_constructor(): + # start server + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + def check_multi_hop(): + # use server0 as proxy to connect to server1 + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + + fecho = client.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + nd = tvm.nd.array([1,2,3], ctx=client.cpu(0)) + assert(nd.asnumpy()[1] == 2) + + def check_error_handling(): + with pytest.raises(tvm.error.RPCError): + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor=["rpc.NonExistingConstructor"]) + + check_multi_hop() + check_error_handling() + + def test_rpc_return_ndarray(): # Use closure to check the ref counter correctness nd = tvm.nd.array(np.zeros(10).astype("float32")) @@ -221,6 +268,7 @@ def my_module(name): # start server server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") + m = client.get_function("rpc.test.remote_return_nd") get_arr = m("get_arr") ref_count = m("ref_count") @@ -315,6 +363,7 @@ def target(host, port, device_key, timeout): time.sleep(0.5) summary = client.summary() + assert summary['queue_info'][device_key]['free'] == 0 assert summary['queue_info'][device_key]['pending'] == 1 @@ -334,6 +383,8 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_session_constructor() + test_rpc_echo() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js index b62b298d969e2..86ef59cb73b19 100644 --- a/web/tvm_runtime.js +++ b/web/tvm_runtime.js @@ -907,7 +907,7 @@ var tvm_runtime = tvm_runtime || {}; if (typeof systemFunc.fcreateServer === "undefined") { systemFunc.fcreateServer = - getGlobalFunc("rpc._CreateEventDrivenServer"); + getGlobalFunc("rpc.CreateEventDrivenServer"); } if (systemFunc.fcreateServer == null) { throwError("RPCServer is not included in runtime");