From 5b68ab6b7cd2b58b55a632998fdc1cb71605a3d2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 30 Apr 2020 14:51:31 -0700 Subject: [PATCH 1/4] Update dmlc-core which was mistakenly overriden --- 3rdparty/dmlc-core | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 808f485387f9..ff3db4367a30 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 808f485387f9a03f78fa9f1159f387d0d91b7a28 +Subproject commit ff3db4367a30f542aafb83b4af45e685b80102d0 From d7bb20fc3bb6d15b63a1e69b62f2603122782296 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Apr 2020 11:48:56 -0700 Subject: [PATCH 2/4] [REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra. This PR refactors the RPC protocol to make it more modularized. - RPCSession: represent a set of features that need to be implemented - RPCEndPont: End point that forwards the RPCSession requests over a communication channel. - RPCModule: Exposes an RPCSession as an rpc device in the TVM Runtime API. In the new design, the local machine is presented as a special case of RPCSession. The remote is just another client session that calls into RPCEndPoint. The RPC communication path is as follows. ``` client -> ClientSession -> EndPoint[client@n0] -> networking[between n0 <=> n1] -> EndPoint[server@n1] -> LocalSession[@n1] ``` Because of the new modular design, we can now chain more sessions together. For example, we can now run the following proxy setup (testcase in test_runtime_rpc.test_session_constructor). ``` client -> ClientSession -> Endpoint[client@n0] -> networking[between n0 <=> n1] -> Endpoint[server@n1] -> ClientSession -> Endpoint[client@n1] -> networking[between n1 <=> n2] -> Endpoint[server@n2] -> LocalSession[@n2] ``` We can also implement other types of Sessions. As an example, We introduced a PopenSession that communicates with the another process via a pipe. We also add more comments about the internal of the RPC. The communication protocol is simplfied using a similar convention as PackedFunc. This allows us to further reduce the amount of special remote syscalls. Due to the major improvement and simplification, we are making a non-compatible update to the RPC protocol. It means that the client and server needs to be upgraded to together in order for it to function correctly. This PR also introduces a versioning mechanism to the current RPC procotol, so that future upgrade will be produce more user friendly with error messages. --- .gitignore | 3 +- apps/bundle_deploy/runtime.cc | 1 - apps/cpp_rpc/rpc_server.cc | 20 +- apps/cpp_rpc/rpc_tracker_client.h | 2 +- include/tvm/runtime/c_runtime_api.h | 48 + include/tvm/runtime/device_api.h | 10 +- .../org/apache/tvm/contrib/GraphRuntime.java | 53 +- .../main/java/org/apache/tvm/rpc/Client.java | 2 +- .../org/apache/tvm/rpc/NativeServerLoop.java | 2 +- .../java/org/apache/tvm/rpc/RPCSession.java | 4 +- python/tvm/_ffi/_ctypes/packed_func.py | 11 + python/tvm/_ffi/_cython/packed_func.pxi | 12 + python/tvm/_ffi/base.py | 5 +- python/tvm/contrib/cc.py | 10 +- python/tvm/contrib/graph_runtime.py | 9 +- python/tvm/error.py | 5 + python/tvm/rpc/__init__.py | 4 +- python/tvm/rpc/_ffi_api.py | 21 + python/tvm/rpc/base.py | 7 - python/tvm/rpc/client.py | 102 +- python/tvm/rpc/minrpc.py | 79 ++ python/tvm/rpc/proxy.py | 3 +- python/tvm/rpc/server.py | 7 +- python/tvm/runtime/module.py | 6 +- src/runtime/c_runtime_api.cc | 37 + src/runtime/module.cc | 2 +- src/runtime/registry.cc | 22 +- src/runtime/rpc/minrpc/minrpc_server.h | 598 ++++++++ src/runtime/rpc/minrpc/posix_popen_server.cc | 74 + src/runtime/rpc/rpc_channel.cc | 52 + src/runtime/rpc/rpc_channel.h | 98 ++ src/runtime/rpc/rpc_device_api.cc | 53 +- src/runtime/rpc/rpc_endpoint.cc | 1059 ++++++++++++++ src/runtime/rpc/rpc_endpoint.h | 226 +++ src/runtime/rpc/rpc_event_impl.cc | 18 +- src/runtime/rpc/rpc_local_session.cc | 146 ++ src/runtime/rpc/rpc_local_session.h | 82 ++ src/runtime/rpc/rpc_module.cc | 402 ++++-- src/runtime/rpc/rpc_pipe_impl.cc | 133 ++ src/runtime/rpc/rpc_protocol.h | 487 +++++++ src/runtime/rpc/rpc_server_env.cc | 7 +- src/runtime/rpc/rpc_session.cc | 1263 +---------------- src/runtime/rpc/rpc_session.h | 407 ++---- src/runtime/rpc/rpc_socket_impl.cc | 59 +- src/support/arena.h | 163 ++- tests/python/unittest/test_runtime_rpc.py | 120 +- web/tvm_runtime.js | 2 +- 47 files changed, 4070 insertions(+), 1866 deletions(-) create mode 100644 python/tvm/rpc/_ffi_api.py create mode 100644 python/tvm/rpc/minrpc.py create mode 100644 src/runtime/rpc/minrpc/minrpc_server.h create mode 100644 src/runtime/rpc/minrpc/posix_popen_server.cc create mode 100644 src/runtime/rpc/rpc_channel.cc create mode 100644 src/runtime/rpc/rpc_channel.h create mode 100644 src/runtime/rpc/rpc_endpoint.cc create mode 100644 src/runtime/rpc/rpc_endpoint.h create mode 100644 src/runtime/rpc/rpc_local_session.cc create mode 100644 src/runtime/rpc/rpc_local_session.h create mode 100644 src/runtime/rpc/rpc_pipe_impl.cc create mode 100644 src/runtime/rpc/rpc_protocol.h diff --git a/.gitignore b/.gitignore index 068cb87484a0..1fcb2dc2d3fc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,10 @@ __pycache__/ *.py[cod] *$py.class - +*.S # C extensions *.so +*.ll # Distribution / packaging .Python diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 7a116e89fa88..844f404d98f4 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include #include #include diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index ea4ab00c113b..57a68f452d3d 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -33,7 +33,7 @@ #include #include "../../src/support/socket.h" -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" #include "rpc_env.h" #include "rpc_server.h" @@ -86,7 +86,7 @@ class RPCServer { tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), custom_addr_(std::move(custom_addr)) { - + } /*! @@ -98,7 +98,7 @@ class RPCServer { tracker_sock_.Close(); listen_sock_.Close(); } catch(...) { - + } } @@ -144,7 +144,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -197,7 +197,7 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - + } auto dur = high_resolution_clock::now() - start_time; @@ -217,10 +217,10 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, + support::SockAddr* addr, + std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) { tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); - + sock.Close(); env.CleanUp(); } @@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") +TVM_REGISTER_GLOBAL("rpc.ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index dfd576f4c195..9b9a707dd376 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -31,7 +31,7 @@ #include #include -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_end_point.h" #include "../../src/support/socket.h" namespace tvm { diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 920ecfbf9b13..79bcdc6c0573 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -550,6 +550,54 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); */ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); +/*! + * \brief Allocate a data space on device. + * \param ctx The device context to perform operation. + * \param nbytes The number of bytes in memory. + * \param alignment The alignment of the memory. + * \param type_hint The type of elements. Only needed by certain backends such + * as nbytes & alignment are sufficient for most backends. + * \param out_data The allocated device pointer. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint, + void** out_data); + +/*! + * \brief Free a data space on device. + * \param ctx The device context to perform operation. + * \param ptr The data space. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr); + +/*! + * \brief Copy data from one place to another. + * \param from The source array. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param num_bytes The size of the memory in bytes + * \param ctx_from The source context + * \param ctx_to The target context + * \param type_hint The type of elements, only neded by certain backends. + * can be useful for cross device endian converison. + * \param stream Optional stream object. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t num_bytes, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f2ddc84e9f98..12069182354b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI { * \param event_dst The destination stream to synchronize. */ virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, - TVMStreamHandle event_dst); - /*! + TVMStreamHandle event_src, + TVMStreamHandle event_dst); + /*! * \brief Allocate temporal workspace for backend execution. * * \note We have the following assumption about backend temporal @@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI { * as OpenGL, as nbytes is sufficient for most backends. */ virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + size_t nbytes, + DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. * diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java index c31c67f283af..61ff966eaf38 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java @@ -38,53 +38,14 @@ public class GraphRuntime { * @return Runtime graph module that can be used to execute the graph. */ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) { - Module graphModule = null; - if (ctx.deviceType >= RPC.RPC_SESS_MASK) { - if (!(ctx instanceof TVMRemoteContext)) { - throw new IllegalArgumentException( - "Looks like you are using remote context with no RPCSession bind." - + "Use session.context instead."); - } - RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession; - // check arguments - if (!"rpc".equals(libmod.typeKey())) { - throw new IllegalArgumentException("libmod.typeKey != rpc"); - } - final int sessIndex = (int) ((Function) reflectionStaticCall( - RPC.class, "getApi", "_SessTableIndex")) - .pushArg(libmod).invoke().asLong(); - if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { - throw new IllegalArgumentException(String.format( - "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d", - sessIndex, reflectionGetField(rpcSession, "tblIndex"))); - } - - Function rpcModuleHandle = (Function) reflectionStaticCall( - RPC.class, "getApi","_ModuleHandle"); - if (rpcModuleHandle == null) { - throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." - + "Did you compile tvm_runtime with the correct version?"); - } - - Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." - + "Did you compile tvm_runtime with correct version?"); - } - - TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke(); - graphModule = fcreate.call(graphJson, hmod, - ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule(); - } else { - Function fcreate = Function.getFunction("tvm.graph_runtime.create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." - + "Did you compile tvm_runtime with correct version?"); - } - graphModule = fcreate.pushArg(graphJson) - .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) - .invoke().asModule(); + Function fcreate = Function.getFunction("tvm.graph_runtime.create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." + + "Did you compile tvm_runtime with correct version?"); } + Module graphModule = fcreate.pushArg(graphJson) + .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) + .invoke().asModule(); return new GraphModule(graphModule, ctx); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 5178ac900a36..69321c3b51c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -29,7 +29,7 @@ public class Client { * @return The connected session. */ public static RPCSession connect(String url, int port, String key) { - Function doConnect = RPC.getApi("_Connect"); + Function doConnect = RPC.getApi("Connect"); if (doConnect == null) { throw new RuntimeException("Please compile with USE_RPC=1"); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java index 29a457f39a40..1f3191fb2e8c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); + RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 92b328488b40..b9f621473cf4 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -39,7 +39,7 @@ public class RPCSession { RPCSession(Module sess) { session = sess; - tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong(); } /** @@ -237,7 +237,7 @@ public byte[] download(String path) { * @return The remote module containing remote function. */ public Module loadModule(String path) { - return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); } diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index dc2dc1944f30..6d2b966b8815 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -141,6 +141,17 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) type_codes[i] = TypeCode.TVM_CONTEXT + elif isinstance(arg, bytes): + byte_arr = bytearray(arg) + arr = TVMByteArray() + arr.data = ctypes.cast( + (ctypes.c_byte * len(arg)).from_buffer(byte_arr), + ctypes.POINTER(ctypes.c_byte)) + arr.size = len(arg) + values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) + temp_args.append(byte_arr) + temp_args.append(arr) + type_codes[i] = TypeCode.BYTES elif isinstance(arg, bytearray): arr = TVMByteArray() arr.data = ctypes.cast( diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 1f68df1885db..4a1bdfd97817 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -142,6 +142,18 @@ cdef inline int make_arg(object arg, value[0].v_ctx = (( ctypes.addressof(arg)))[0] tcode[0] = kTVMContext + elif isinstance(arg, bytes): + byte_arr = bytearray(arg) + arr = TVMByteArray() + arr.data = ctypes.cast( + (ctypes.c_byte * len(arg)).from_buffer(byte_arr), + ctypes.POINTER(ctypes.c_byte)) + arr.size = len(arg) + value[0].v_handle = ( + ctypes.addressof(arr)) + tcode[0] = kTVMBytes + temp_args.append(byte_arr) + temp_args.append(arr) elif isinstance(arg, bytearray): arr = TVMByteArray() arr.data = ctypes.cast( diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 8d3ce19f9444..61360107f671 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -48,8 +48,11 @@ def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - # DMatrix functions lib.TVMGetLastError.restype = ctypes.c_char_p + # Put the libpath to LD_LIBRARY_PATH + # will be useful for pipe session to find libtvm + os.environ["LD_LIBRARY_PATH"] = "%s:%s" % ( + os.path.dirname(lib_path[0]), os.environ.get("LD_LIBRARY_PATH", "")) return lib, os.path.basename(lib_path[0]) # version number diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ae37923a1dcf..8ad47acfe989 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -90,7 +90,8 @@ def get_target_triple(): def cross_compiler(compile_func, options=None, output_format=None, - get_target_triple=None): + get_target_triple=None, + add_files=None): """Create a cross compiler function by specializing compile_func with options. This function can be used to construct compile functions that @@ -111,6 +112,10 @@ def cross_compiler(compile_func, get_target_triple: Optional[Callable] Function that can target triple according to dumpmachine option of compiler. + add_files: Optional[List[str]] + List of paths to additional object, source, library files + to pass as part of the compilation. + Returns ------- fcompile : Callable[[str, str, Optional[str]], None] @@ -133,6 +138,7 @@ def cross_compiler(compile_func, """ base_options = [] if options is None else options kwargs = {} + add_files = [] if add_files is None else add_files # handle case where compile_func is the name of the cc if isinstance(compile_func, str): @@ -144,7 +150,7 @@ def _fcompile(outputs, objects, options=None): all_options = base_options if options is not None: all_options += options - compile_func(outputs, objects, options=all_options, **kwargs) + compile_func(outputs, objects + add_files, options=all_options, **kwargs) if not output_format and hasattr(compile_func, "output_format"): output_format = compile_func.output_format diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 73235f71c77b..740d1c3f19f3 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -18,9 +18,10 @@ import numpy as np import tvm._ffi -from .._ffi.base import string_types -from .._ffi.runtime_ctypes import TVMContext -from ..rpc import base as rpc_base +from tvm.rpc import _ffi_api as _rpc_ffi_api +from tvm.rpc import base as rpc_base +from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import TVMContext def create(graph_json_str, libmod, ctx): @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx): device_type = cur_ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex( + assert _rpc_ffi_api.SessTableIndex( libmod) == cur_ctx._rpc_sess._tbl_index num_rpc_ctx += 1 device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK diff --git a/python/tvm/error.py b/python/tvm/error.py index 4c3e6060c25a..7366fdb3be39 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -57,6 +57,11 @@ def __init__(self, msg): register_error("KeyError", KeyError) +@register_error +class RPCError(RuntimeError): + """Error thrown by the RPC call.""" + + @register_error class OpError(TVMError): """Base class of all operator errors in frontends.""" diff --git a/python/tvm/rpc/__init__.py b/python/tvm/rpc/__init__.py index 5f959eb44745..b64ba33d9e09 100644 --- a/python/tvm/rpc/__init__.py +++ b/python/tvm/rpc/__init__.py @@ -26,4 +26,6 @@ """ from .server import Server -from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker +from .client import connect, connect_tracker +from .client import RPCSession, LocalSession, PopenSession, TrackerSession +from .minrpc import with_minrpc diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py new file mode 100644 index 000000000000..1a7cc739b5c1 --- /dev/null +++ b/python/tvm/rpc/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.rpc""" +import tvm._ffi + + +tvm._ffi._init_api("rpc", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index bc81534a12d9..f0e33f8503f2 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -17,8 +17,6 @@ """Base definitions for RPC.""" # pylint: disable=invalid-name -from __future__ import absolute_import - import socket import time import json @@ -26,7 +24,6 @@ import struct import random import logging -import tvm._ffi from .._ffi.base import py_str @@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5): logger.warning("Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period) time.sleep(retry_period) - - -# Still use tvm.rpc for the foreign functions -tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base") diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ed57e0d4276d..d4250353f8f9 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -15,19 +15,20 @@ # specific language governing permissions and limitations # under the License. """RPC client tools""" -from __future__ import absolute_import - import os +import stat import socket import struct import time + import tvm._ffi from tvm.contrib import util from tvm._ffi.base import TVMError from tvm.runtime import ndarray as nd -from tvm.runtime import load_module as _load_module from . import base +from . import server +from . import _ffi_api class RPCSession(object): @@ -38,9 +39,23 @@ class RPCSession(object): # pylint: disable=invalid-name def __init__(self, sess): self._sess = sess - self._tbl_index = base._SessTableIndex(sess) + self._tbl_index = _ffi_api.SessTableIndex(sess) self._remote_funcs = {} + def system_lib(self): + """Get system-wide library module. + + Returns + ------- + module : runtime.Module + The system-wide library module. + + See Also + -------- + tvm.runtime.system_lib + """ + return self.get_function("runtime.SystemLib")() + def get_function(self, name): """Get function from the session. @@ -145,7 +160,7 @@ def load_module(self, path): m : Module The remote module containing remote function. """ - return base._LoadRemoteModule(self._sess, path) + return _ffi_api.LoadRemoteModule(self._sess, path) def cpu(self, dev_id=0): """Construct CPU device.""" @@ -183,28 +198,37 @@ class LocalSession(RPCSession): need to be ran both locally and remotely. """ def __init__(self): - # pylint: disable=super-init-not-called - self.context = nd.context - self.get_function = tvm._ffi.get_global_func - self._temp = util.tempdir() + self._temp = server._server_env([]) + RPCSession.__init__(self, _ffi_api.LocalSession()) - def upload(self, data, target=None): - if isinstance(data, bytearray): - if not target: - raise ValueError("target must present when file is a bytearray") - blob = data - else: - blob = bytearray(open(data, "rb").read()) - if not target: - target = os.path.basename(data) - with open(self._temp.relpath(target), "wb") as f: - f.write(blob) - def download(self, path): - return bytearray(open(self._temp.relpath(path), "rb").read()) +@tvm._ffi.register_func("rpc.PopenSession") +def _popen_session(binary): + temp = util.tempdir() - def load_module(self, path): - return _load_module(self._temp.relpath(path)) + if isinstance(binary, (bytes, bytearray)): + path_exec = temp.relpath("server.minrpc") + with open(path_exec, "wb") as outfile: + outfile.write(binary) + os.chmod(path_exec, stat.S_IXUSR) + path_exec = os.path.abspath(path_exec) + else: + path_exec = os.path.abspath(binary) + + sess = _ffi_api.CreatePipeClient(path_exec) + return sess + + +class PopenSession(RPCSession): + """RPCSession interface backed by popen. + + Parameters + ---------- + binary : List[Union[str, bytes]] + The binary to be executed. + """ + def __init__(self, binary): + RPCSession.__init__(self, _popen_session(binary)) class TrackerSession(object): @@ -378,7 +402,7 @@ def request_and_run(self, key, max_retry, str(last_err))) -def connect(url, port, key="", session_timeout=0): +def connect(url, port, key="", session_timeout=0, session_constructor=None): """Connect to RPC Server Parameters @@ -397,15 +421,41 @@ def connect(url, port, key="", session_timeout=0): the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. + session_constructor: List + List of additional arguments to passed as the remote session constructor. + Returns ------- sess : RPCSession The connected session. + + Examples + -------- + Normal usage + .. code-block:: python + + client = rpc.connect(server_url, server_port, server_key) + + Session_constructor can be used to customize the session in the remote + The following code connects to a remote internal server via a proxy + by constructing another RPCClientSession on the proxy machine and use that + as the serving session of the proxy endpoint. + + .. code-block:: python + + client_via_proxy = rpc.connect( + proxy_server_url, proxy_server_port, proxy_server_key, + session_constructor=[ + "rpc.Connect", internal_url, internal_port, internal_key]) + """ try: if session_timeout: key += " -timeout=%s" % str(session_timeout) - sess = base._Connect(url, port, key) + session_constructor = session_constructor if session_constructor else [] + if not isinstance(session_constructor, (list, tuple)): + raise TypeError("Expect the session constructor to be a list or tuple") + sess = _ffi_api.Connect(url, port, key, *session_constructor) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py new file mode 100644 index 000000000000..768b52886ca2 --- /dev/null +++ b/python/tvm/rpc/minrpc.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utils to path.""" +import os +from tvm._ffi import libinfo +from tvm.contrib import cc + + +def find_minrpc_server_libpath(server="posix_popen_server"): + """Get the path of minrpc server libary. + + Parameters + ---------- + server : str + The kind of built in minrpc server. + + Returns + ------- + path : str + The path to the min server library. + """ + curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", "..")) + + path = os.path.join( + source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server)) + + candidates = [path] + if not os.path.isfile(path): + raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates)) + return path + + +def with_minrpc(compile_func, + server="posix_popen_server", + runtime="libtvm"): + """Attach the compiler function with minrpc related options. + + Parameters + ---------- + compile_func : function + The compilation function to decorate. + + server : str + The server type. + + runtime : str + The runtime library. + + Returns + ------- + fcompile : function + The return compilation. + """ + server_path = find_minrpc_server_libpath(server) + runtime_path = libinfo.find_lib_path( + [runtime, runtime + ".so", runtime + ".dylib"])[0] + + fcompile = cc.cross_compiler( + compile_func, + options=["-std=c++14"] + ["-I" + path for path in libinfo.find_include_path()], + add_files=[server_path, runtime_path]) + fcompile.__name__ = "with_minrpc" + fcompile.need_system_lib = True + return fcompile diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index c3a3647948ee..03746dad6d62 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -42,6 +42,7 @@ raise ImportError( "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) +from . import _ffi_api from . import base from .base import TrackerCode from .server import _server_env @@ -549,7 +550,7 @@ def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) - on_message = base._CreateEventDrivenServer( + on_message = _ffi_api.CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 03749c1c17e4..15a3c7de789d 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import util +from . import _ffi_api from . import base from . base import TrackerCode @@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) @@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" sockfd = sock.fileno() temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) + _ffi_api.ServerLoop(sockfd) if not work_path: temp.remove() logger.info("Finish serving %s", addr) @@ -330,7 +331,7 @@ def __init__(self, utvm_dev_config_args=None, ): try: - if base._ServerLoop is None: + if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 716f87f33fc1..b580e3f6dc6d 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -244,6 +244,7 @@ def _dso_exportable(self): def export_library(self, file_name, fcompile=None, + addons=None, **kwargs): """Export the module and its imported device code one library. @@ -283,7 +284,7 @@ def export_library(self, modules = self._collect_dso_modules() temp = _util.tempdir() - files = [] + files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None @@ -313,6 +314,9 @@ def export_library(self, if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"): llvm_target_triple = fcompile.get_target_triple() + if getattr(fcompile, "need_system_lib", False) and not is_system_lib: + raise ValueError("%s need --system-lib option" % str(fcompile)) + if self.imported_modules: if enabled("llvm") and llvm_target_triple: path_obj = temp.relpath("devc.o") diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index fb1f74da2103..32b3381eeb2e 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -460,6 +460,7 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); + TVMRetValue rv; (*static_cast(func)).CallPacked( TVMArgs(args, arg_type_codes, num_args), &rv); @@ -585,6 +586,42 @@ int TVMCbArgToReturn(TVMValue* value, int* code) { API_END(); } + +int TVMDeviceAllocDataSpace(DLContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint, + void** out_data) { + API_BEGIN(); + out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace( + ctx, nbytes, alignment, type_hint); + API_END(); +} + +int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) { + API_BEGIN(); + DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr); + API_END(); +} + +int TVMDeviceCopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t num_bytes, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream) { + API_BEGIN(); + TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to; + DeviceAPIManager::Get(ctx)->CopyDataFromTo( + from, from_offset, + to, to_offset, + num_bytes, ctx_from, ctx_to, type_hint, stream); + API_END(); +} + // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/src/runtime/module.cc b/src/runtime/module.cc index d2ed7ff9e2b7..813a79d43c06 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -36,7 +36,7 @@ void ModuleNode::Import(Module other) { if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { - fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); + fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule"); CHECK(fimport_ != nullptr); } (*fimport_)(GetRef(this), other); diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 4717d89e33c1..855a342a7e97 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -37,7 +37,7 @@ struct Registry::Manager { // map storing the functions. // We delibrately used raw pointer // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction. + // and the resource can become invalid because of indeterminstic order of destruction and forking. // The resources will only be recycled during program exit. std::unordered_map fmap; // mutex @@ -60,20 +60,18 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) return *this; } -Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) +Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); - auto it = m->fmap.find(name); - if (it == m->fmap.end()) { - Registry* r = new Registry(); - r->name_ = name; - m->fmap[name] = r; - return *r; - } else { - CHECK(override) - << "Global PackedFunc " << name << " is already registered"; - return *it->second; + if (m->fmap.count(name)) { + CHECK(can_override) + << "Global PackedFunc " << name << " is already registered"; } + + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; } bool Registry::Remove(const std::string& name) { diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h new file mode 100644 index 000000000000..370720d14be6 --- /dev/null +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -0,0 +1,598 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file minrpc_server.h + * \brief Minimum RPC server implementation, + * redirects all the calls to C runtime API. + * + * \note This file do not depend on c++ std or c std, + * and only depends on TVM's C runtime API. + */ +#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ + +#include +#include +#include "../rpc_protocol.h" +#include "../../../support/arena.h" + +/*! \brief Whether or not to enable glog style DLOG */ +#ifndef TVM_MINRPC_ENABLE_LOGGING +#define TVM_MINRPC_ENABLE_LOGGING 0 +#endif + +#ifndef MINRPC_CHECK +#define MINRPC_CHECK(cond) \ + if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); +#endif + +#if TVM_MINRPC_ENABLE_LOGGING +#include +#endif + + +namespace tvm { +namespace runtime { + +/*! + * \brief A minimum RPC server that only depends on the tvm C runtime.. + * + * All the dependencies are provided by the io arguments. + * + * \tparam TIOHandler IO provider to provide io handling. + * An IOHandler needs to provide the following functions: + * - PosixWrite, PosixRead, Close: posix style, read, write, close API. + * - Exit: exit with status code. + */ +template +class MinRPCServer { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + explicit MinRPCServer(TIOHandler io) + : io_(io), arena_(PageAllocator(io)) {} + + /*! \brief Run the server loop until shutdown signal is received. */ + void ServerLoop() { + RPCCode code; + uint64_t packet_len; + + while (true) { + arena_.RecycleAll(); + allow_clean_shutdown_ = true; + + this->Read(&packet_len); + if (packet_len == 0) continue; + this->Read(&code); + + allow_clean_shutdown_ = false; + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + this->Shutdown(); + return; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; + } + } + } + } + } + + void Shutdown() { + arena_.FreeAll(); + io_.Close(); + } + + void HandleNormalCallFunc() { + uint64_t call_handle; + TVMValue* values; + int* tcodes; + int num_args; + TVMValue ret_value[3]; + int ret_tcode[3]; + + this->Read(&call_handle); + RecvPackedSeq(&values, &tcodes, &num_args); + + int call_ecode = TVMFuncCall( + reinterpret_cast(call_handle), + values, tcodes, num_args, + &(ret_value[1]), &(ret_tcode[1])); + + if (call_ecode == 0) { + // Return value encoding as in LocalSession + int rv_tcode = ret_tcode[1]; + ret_tcode[0] = kDLInt; + ret_value[0].v_int64 = rv_tcode; + if (rv_tcode == kTVMNDArrayHandle) { + ret_tcode[1] = kTVMDLTensorHandle; + ret_value[2].v_handle = ret_value[1].v_handle; + ret_tcode[2] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 3); + } else if (rv_tcode == kTVMPackedFuncHandle || + rv_tcode == kTVMModuleHandle) { + ret_tcode[1] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } else { + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + char* data_ptr; + int call_ecode = 0; + if (ctx.device_type == kDLCPU) { + data_ptr = reinterpret_cast(handle) + offset; + } else { + data_ptr = this->ArenaAlloc(num_bytes); + call_ecode = TVMDeviceCopyDataFromTo( + reinterpret_cast(handle), offset, + data_ptr, 0, num_bytes, + ctx, DLContext{kDLCPU, 0}, + type_hint, nullptr); + } + + if (call_ecode == 0) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + int call_ecode = 0; + + if (ctx.device_type == kDLCPU) { + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + } else { + char* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + call_ecode = TVMDeviceCopyDataFromTo( + temp_data, 0, + reinterpret_cast(handle), offset, + num_bytes, + DLContext{kDLCPU, 0}, ctx, + type_hint, nullptr); + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleSyscallFunc(RPCCode code) { + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + switch (code) { + case RPCCode::kFreeHandle: { + this->SyscallFreeHandle(values, tcodes, num_args); + break; + } + case RPCCode::kGetGlobalFunc: { + this->SyscallGetGlobalFunc(values, tcodes, num_args); + break; + } + case RPCCode::kDevSetDevice: { + this->ReturnException("SetDevice not supported"); + break; + } + case RPCCode::kDevGetAttr: { + this->ReturnException("GetAttr not supported"); + break; + } + case RPCCode::kDevAllocData: { + this->SyscallDevAllocData(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeData: { + this->SyscallDevFreeData(values, tcodes, num_args); + break; + } + case RPCCode::kDevStreamSync: { + this->SyscallDevStreamSync(values, tcodes, num_args); + break; + } + case RPCCode::kCopyAmongRemote: { + this->SyscallCopyAmongRemote(values, tcodes, num_args); + break; + } + default: { + this->ReturnException("Syscall not recognized"); + break; + } + } + } + + void HandleInitServer() { + uint64_t len; + this->Read(&len); + char* proto_ver = this->ArenaAlloc(len + 1); + this->ReadArray(proto_ver, len); + + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + MINRPC_CHECK(num_args == 0); + this->ReturnVoid(); + } + + void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + + void* handle = values[0].v_handle; + int64_t type_code = values[1].v_int64; + int call_ecode; + + if (type_code == kTVMNDArrayHandle) { + call_ecode = TVMArrayFree(static_cast(handle)); + } else if (type_code == kTVMPackedFuncHandle) { + call_ecode = TVMFuncFree(handle); + } else { + MINRPC_CHECK(type_code == kTVMModuleHandle); + call_ecode = TVMModFree(handle); + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kTVMStr); + + void* handle; + int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 9); + // from, from_offset + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + // to, to_offset + MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[3] == kDLInt); + // size + MINRPC_CHECK(tcodes[4] == kDLInt); + // ctx_from, ctx_to + MINRPC_CHECK(tcodes[5] == kTVMContext); + MINRPC_CHECK(tcodes[6] == kTVMContext); + // type_hint, stream + MINRPC_CHECK(tcodes[7] == kTVMDataType); + MINRPC_CHECK(tcodes[8] == kTVMOpaqueHandle); + + void* from = values[0].v_handle; + int64_t from_offset = values[1].v_int64; + void* to = values[2].v_handle; + int64_t to_offset = values[3].v_int64; + int64_t size = values[4].v_int64; + TVMContext ctx_from = values[5].v_ctx; + TVMContext ctx_to = values[6].v_ctx; + DLDataType type_hint = values[7].v_type; + TVMStreamHandle stream = values[8].v_handle; + + int call_ecode = TVMDeviceCopyDataFromTo( + from, from_offset, + to, to_offset, size, + ctx_from, ctx_to, type_hint, stream); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 4); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kDLInt); + MINRPC_CHECK(tcodes[2] == kDLInt); + MINRPC_CHECK(tcodes[3] == kTVMDataType); + + TVMContext ctx = values[0].v_ctx; + int64_t nbytes = values[1].v_int64; + int64_t alignment = values[2].v_int64; + DLDataType type_hint = values[3].v_type; + + void* handle; + int call_ecode = TVMDeviceAllocDataSpace( + ctx, nbytes, alignment, type_hint, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMDeviceFreeDataSpace(ctx, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + io_.Exit(static_cast(code)); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + template + void Read(T* data) { + static_assert(std::is_pod::value, "need to be trival"); + this->ReadRawBytes(data, sizeof(T)); + } + + template + void ReadArray(T* data, size_t count) { + static_assert(std::is_pod::value, "need to be trival"); + return this->ReadRawBytes(data, sizeof(T) * count); + } + + template + void Write(const T& data) { + static_assert(std::is_pod::value, "need to be trival"); + return this->WriteRawBytes(&data, sizeof(T)); + } + + template + void WriteArray(T* data, size_t count) { + static_assert(std::is_pod::value, "need to be trival"); + return this->WriteRawBytes(data, sizeof(T) * count); + } + + private: + // Internal allocator that redirects alloc to TVM's C API. + class PageAllocator { + public: + using ArenaPageHeader = tvm::support::ArenaPageHeader; + + explicit PageAllocator(TIOHandler io) + : io_(io) {} + + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + void* data; + + if (TVMDeviceAllocDataSpace( + DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, + DLDataType{kDLInt, 1, 1}, &data) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + + ArenaPageHeader* header = static_cast(data); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + + void deallocate(ArenaPageHeader* page) { + if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + } + + static const constexpr int kPageSize = 2 << 10; + static const constexpr int kPageAlign = 8; + + private: + TIOHandler io_; + }; + + void RecvPackedSeq(TVMValue** out_values, + int** out_tcodes, + int* out_num_args) { + RPCReference::RecvPackedSeq( + out_values, out_tcodes, out_num_args, this); + } + + void ReturnVoid() { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + sizeof(tcode); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + } + + void ReturnHandle(void* handle) { + int32_t num_args = 1; + int32_t tcode = kTVMOpaqueHandle; + RPCCode code = RPCCode::kReturn; + uint64_t encode_handle = reinterpret_cast(handle); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + + sizeof(tcode) + sizeof(encode_handle); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + this->Write(encode_handle); + } + + void ReturnException(const char* msg) { + RPCReference::ReturnException(msg, this); + } + + void ReturnPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args) { + RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); + } + + void ReturnLastTVMError() { + this->ReturnException(TVMGetLastError()); + } + + void ReadRawBytes(void* data, size_t size) { + char* buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixRead(buf, size - ndone); + if (ret == 0) { + if (allow_clean_shutdown_) { + this->Shutdown(); + io_.Exit(0); + } else { + this->ThrowError(RPCServerStatus::kReadError); + } + } + if (ret == -1) { + this->ThrowError(RPCServerStatus::kReadError); + } + ndone += ret; + buf += ret; + } + } + + void WriteRawBytes(const void* data, size_t size) { + const char *buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixWrite(buf, size - ndone); + if (ret == 0 || ret == -1) { + this->ThrowError(RPCServerStatus::kWriteError); + } + buf += ret; + ndone += ret; + } + } + + /*! \brief IO handler. */ + TIOHandler io_; + /*! \brief internal arena. */ + support::GenericArena arena_; + /*! \brief Whether we are in a state that allows clean shutdown. */ + bool allow_clean_shutdown_{true}; + static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian."); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc new file mode 100644 index 000000000000..fdc57112f0b9 --- /dev/null +++ b/src/runtime/rpc/minrpc/posix_popen_server.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Disable constructor to bring minimum dep on c++ABI. +#define TVM_ARENA_HAS_DESTRUCTOR 0 + +#include +#include +#include "minrpc_server.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief IOHandler based on posix API. + */ +class PosixIOHandler { + public: + explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) + : read_fd_(read_fd), write_fd_(write_fd) { + } + + ssize_t PosixRead(void* data, size_t size) { + return read(read_fd_, data, size); + } + + ssize_t PosixWrite(const void* data, size_t size) { + return write(write_fd_, data, size); + } + + void Exit(int code) { + exit(code); + } + + void Close() { + if (read_fd_ != 0) close(read_fd_); + if (write_fd_ != 0) close(write_fd_); + } + + private: + int read_fd_{0}; + int write_fd_{1}; +}; + +/*! \brief Type for the posix version of min rpc server. */ +using PosixMinRPCServer = MinRPCServer; + +} // namespace runtime +} // namespace tvm + +int main(int argc, char* argv[]) { + if (argc != 3) return -1; + // pass the descriptor via arguments. + tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2])); + tvm::runtime::PosixMinRPCServer server(handler); + server.ServerLoop(); + return 0; +} diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc new file mode 100644 index 000000000000..f8dc6e636324 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.cc + */ +#include +#include "rpc_channel.h" + +namespace tvm { +namespace runtime { + +size_t CallbackChannel::Send(const void* data, size_t size) { + TVMByteArray bytes; + bytes.data = static_cast(data); + bytes.size = size; + int64_t n = fsend_(bytes); + if (n == -1) { + LOG(FATAL) << "CallbackChannel::Send"; + } + return static_cast(n); +} + +size_t CallbackChannel::Recv(void* data, size_t size) { + TVMRetValue ret = frecv_(size); + + if (ret.type_code() != kTVMBytes) { + LOG(FATAL) << "CallbackChannel::Recv"; + } + std::string* bytes = ret.ptr(); + memcpy(static_cast(data), bytes->c_str(), bytes->length()); + return bytes->length(); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h new file mode 100644 index 000000000000..be34a8b50440 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.h @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_ +#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Abstract channel interface used to create RPCEndpoint. + */ +class RPCChannel { + public: + /*! \brief virtual destructor */ + virtual ~RPCChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + virtual size_t Send(const void* data, size_t size) = 0; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + virtual size_t Recv(void* data, size_t size) = 0; +}; + +/*! + * \brief RPC channel which callback + * frontend (Python/Java/etc.)'s send & recv function + */ +class CallbackChannel final : public RPCChannel { + public: + /*! + * \brief Constructor. + * + * \param fsend The send function, takes in a TVMByteArray and returns the + * number of bytes sent in that array. Returns -1 if error happens. + * \param frecv The recv function, takes an expected maximum size, and return + * a byte array with the actual amount of data received. + */ + explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} + + ~CallbackChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + size_t Send(const void* data, size_t size) final; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + size_t Recv(void* data, size_t size) final; + + private: + PackedFunc fsend_; + PackedFunc frecv_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_H_ diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 9fd45acd14bf..ade4d1683fb1 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "rpc_session.h" namespace tvm { @@ -31,20 +32,24 @@ namespace runtime { class RPCDeviceAPI final : public DeviceAPI { public: void SetDevice(TVMContext ctx) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevSetDevice, ctx); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx); } + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { - *rv = GetSess(ctx)->CallRemote( - RPCCode::kDevGetAttr, ctx, static_cast(kind)); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv); } + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { auto sess = GetSess(ctx); - void *data = sess->CallRemote( - RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + auto remote_ctx = RemoveSessMask(ctx); + void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace( + remote_ctx, nbytes, alignment, type_hint); + RemoteSpace* space = new RemoteSpace(); space->data = data; space->sess = std::move(sess); @@ -52,9 +57,10 @@ class RPCDeviceAPI final : public DeviceAPI { } void FreeDataSpace(TVMContext ctx, void* ptr) final { RemoteSpace* space = static_cast(ptr); + auto remote_ctx = RemoveSessMask(ctx); try { - GetSess(ctx)->CallRemote( - RPCCode::kDevFreeData, ctx, space->data); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace( + remote_ctx, space->data); } catch (const dmlc::Error& e) { // fault tolerance to remote close. } @@ -75,29 +81,35 @@ class RPCDeviceAPI final : public DeviceAPI { to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; - GetSess(ctx_from)->CallRemote( - RPCCode::kCopyAmongRemote, - static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, ctx_from, ctx_to, type_hint, stream); + auto remote_ctx_from = RemoveSessMask(ctx_from); + auto remote_ctx_to = RemoveSessMask(ctx_to); + auto remote_ctx = remote_ctx_from; + if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; + GetSess(ctx_from)->GetDeviceAPI(remote_ctx) + ->CopyDataFromTo(static_cast(from)->data, from_offset, + static_cast(to)->data, to_offset, + size, remote_ctx_from, remote_ctx_to, type_hint, stream); } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { + auto remote_ctx_from = RemoveSessMask(ctx_from); GetSess(ctx_from)->CopyFromRemote( static_cast(from)->data, from_offset, - to, to_offset, size, ctx_from, type_hint); + to, to_offset, size, remote_ctx_from, type_hint); } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { + auto remote_ctx_to = RemoveSessMask(ctx_to); GetSess(ctx_to)->CopyToRemote( - (void*)from, from_offset, // NOLINT(*) + const_cast(from), from_offset, static_cast(to)->data, to_offset, - size, ctx_to, type_hint); + size, remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevStreamSync, ctx, stream); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(ctx, stream); } private: @@ -107,6 +119,11 @@ class RPCDeviceAPI final : public DeviceAPI { int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } + + static TVMContext RemoveSessMask(TVMContext ctx) { + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } }; TVM_REGISTER_GLOBAL("device_api.rpc") diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc new file mode 100644 index 000000000000..916ecaee8a78 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -0,0 +1,1059 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_session.cc + * \brief RPC session for remote function call. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "../object_internal.h" +#include "../../support/ring_buffer.h" +#include "../../support/arena.h" + +namespace tvm { +namespace runtime { + +/*! + * Event-driven state-machine based handlers for RPCEndpoint. + * + * Key functions: + * + * - SendPackedSeq: send the arguments over to the peer + * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). + */ +class RPCEndpoint::EventHandler : public dmlc::Stream { + public: + EventHandler(support::RingBuffer* reader, + support::RingBuffer* writer, + std::string name, + std::string* remote_key) + : reader_(reader), + writer_(writer), + name_(name), + remote_key_(remote_key) { + this->Clear(); + + if (*remote_key == "%toinit") { + state_ = kInitHeader; + remote_key_->resize(0); + pending_request_bytes_ = sizeof(int32_t); + } + } + + /*! + * \brief Bytes needed to fulfill current request + */ + size_t BytesNeeded() const { + if (reader_->bytes_available() < pending_request_bytes_) { + return pending_request_bytes_ - reader_->bytes_available(); + } else { + return 0; + } + } + + /*! + * \brief Request number of bytes from the reader. + * \param nbytes The number of bytes + */ + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + + /*! \return Whether we are ready to handle next request. */ + bool Ready() const { + return reader_->bytes_available() >= pending_request_bytes_; + } + + /*! \return Whether we can perform a clean shutdown */ + bool CanCleanShutdown() const { + return state_ == kRecvPacketNumBytes; + } + + /*! \brief Finish the copy ack stage. */ + void FinishCopyAck() { + this->SwitchToState(kRecvPacketNumBytes); + } + + /*! + * \brief Enter the io loop until the next event. + * \param client_mode Whether we are in the client. + * \param setreturn The function to set the return value encoding. + * \return The function to set return values when there is a return event. + */ + RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { + std::swap(client_mode_, client_mode); + + while (this->Ready()) { + switch (state_) { + case kInitHeader: HandleInitHeader(); break; + case kRecvPacketNumBytes: { + uint64_t packet_nbytes; + CHECK(this->Read(&packet_nbytes)); + if (packet_nbytes != 0) { + this->SwitchToState(kProcessPacket); + this->RequestBytes(packet_nbytes); + } else { + this->SwitchToState(kRecvPacketNumBytes); + } + break; + } + case kProcessPacket: { + this->HandleProcessPacket(setreturn); + break; + } + case kReturnReceived: { + this->SwitchToState(kRecvPacketNumBytes); + std::swap(client_mode_, client_mode); + return RPCCode::kReturn; + } + case kCopyAckReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kCopyAck; + } + case kShutdownReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kShutdown; + } + } + } + std::swap(client_mode_, client_mode); + return RPCCode::kNone; + } + + /*! \brief Clear all the states in the Handler.*/ + void Clear() { + state_ = kRecvPacketNumBytes; + pending_request_bytes_ = sizeof(uint64_t); + } + + /*! + * \brief Validate that the arguments can be sent through RPC. + * \param arg_values The argument values. + * \param type_codes The type codes. + */ + void ValidateArguments(const TVMValue* arg_values, + const int* type_codes, + int num_args) { + TVMArgs args(arg_values, type_codes, num_args); + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i + << ", type " << args[i].AsObjectRef()->GetTypeKey() + << " is not supported by RPC"; + } else if (tcode == kTVMContext) { + DLContext ctx = args[i]; + CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) + << "InternalError: cannot pass RPC context in the channel"; + } + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); + } + + uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode) { + return RPCReference::PackedSeqGetNumBytes( + arg_values, type_codes, num_args, client_mode, this); + } + + void SendPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode) { + RPCReference::SendPackedSeq( + arg_values, type_codes, num_args, client_mode, this); + } + + // Endian aware IO handling + using Stream::Read; + using Stream::Write; + using Stream::ReadArray; + using Stream::WriteArray; + + bool Read(RPCCode* code) { + int32_t cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast(cdata); + return true; + } + void Write(RPCCode code) { + int32_t cdata = static_cast(code); + this->Write(cdata); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + protected: + enum State { + kInitHeader, + kRecvPacketNumBytes, + kProcessPacket, + kReturnReceived, + kCopyAckReceived, + kShutdownReceived + }; + // Current state; + State state_; + // Initialize remote header + bool init_header_step_{0}; + // Whether current handler is client or server mode. + bool client_mode_{false}; + // Internal arena + support::Arena arena_; + + // State switcher + void SwitchToState(State state) { + // invariant + if (state != kCopyAckReceived) { + CHECK_EQ(pending_request_bytes_, 0U) + << "state=" << state; + } + state_ = state; + CHECK(state != kInitHeader) + << "cannot switch to init header"; + if (state == kRecvPacketNumBytes) { + this->RequestBytes(sizeof(uint64_t)); + // recycle arena for the next session. + arena_.RecycleAll(); + } + } + + // handler for initial header read + void HandleInitHeader() { + if (init_header_step_ == 0) { + int32_t len; + this->Read(&len); + remote_key_->resize(len); + init_header_step_ = 1; + this->RequestBytes(len); + return; + } else { + CHECK_EQ(init_header_step_, 1); + this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + + // Handler for read code. + void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kNone; + this->Read(&code); + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscall(code); + } else { + switch (code) { + case RPCCode::kInitServer: { + this->HandleInitServer(); + break; + } + case RPCCode::kCallFunc: { + this->HandleNormalCallFunc(); + break; + } + case RPCCode::kCopyFromRemote: { + this->HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + this->HandleCopyToRemote(); + break; + } + case RPCCode::kException: + case RPCCode::kReturn: { + this->HandleReturn(code, setreturn); + break; + } + case RPCCode::kCopyAck: { + this->SwitchToState(kCopyAckReceived); + break; + } + case RPCCode::kShutdown: { + this->SwitchToState(kShutdownReceived); + break; + } + default: LOG(FATAL) << "Unknown event " << static_cast(code); + } + } + } + + /*! + * \brief Recive incoming packed seq from the stream. + * \return The received argments. + * \note The TVMArgs is available until we switchstate. + */ + TVMArgs RecvPackedSeq() { + TVMValue* values; + int* tcodes; + int num_args; + RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this); + return TVMArgs(values, tcodes, num_args); + } + + /*! + * \brief Return exception to the remote. + * \param err_msg The error message. + */ + void ReturnException(const char* err_msg) { + RPCReference::ReturnException(err_msg, this); + } + + /*! + * \brief Return nullptr to the remote. + * \param err_msg The error message. + */ + void ReturnVoid() { + RPCReference::ReturnVoid(this); + } + + /*! + * \brief Return a packed sequence to the remote. + * \param args The arguments. + */ + void ReturnPackedSeq(TVMArgs args) { + RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this); + } + + /*! + * \brief Handle the case when return/exception value is received. + * \param code The RPC code. + * \param setreturn The function to encode return. + */ + void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { + TVMArgs args = RecvPackedSeq(); + + if (code == RPCCode::kException) { + // switch to the state before sending exception. + this->SwitchToState(kRecvPacketNumBytes); + std::string msg = args[0]; + LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; + } + + CHECK(setreturn != nullptr) << "fsetreturn not available"; + setreturn(args); + + this->SwitchToState(kReturnReceived); + } + + void HandleSyscall(RPCCode code); + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + char* data_ptr; + + if (ctx.device_type == kDLCPU) { + data_ptr = reinterpret_cast(handle) + offset; + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + char* temp = this->ArenaAlloc(num_bytes); + std::memcpy(temp, data_ptr, num_bytes); + dmlc::ByteSwap(temp, elem_bytes, num_bytes / elem_bytes); + data_ptr = temp; + } + } else { + try { + data_ptr = this->ArenaAlloc(num_bytes); + GetServingSession()->CopyFromRemote( + reinterpret_cast(handle), offset, + data_ptr, 0, + num_bytes, ctx, type_hint); + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + } + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + return; + } + } + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + // Return Copy Ack + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + if (ctx.device_type == kDLCPU) { + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + } else { + char* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes); + } + + try { + GetServingSession()->CopyToRemote( + temp_data, 0, + reinterpret_cast(handle), offset, + num_bytes, ctx, type_hint); + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + return; + } + } + + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } + + // Handle for packed call. + void HandleNormalCallFunc() { + uint64_t call_handle; + + this->Read(&call_handle); + TVMArgs args = RecvPackedSeq(); + + try { + GetServingSession()->CallFunc( + reinterpret_cast(call_handle), + args.values, args.type_codes, args.size(), + [this](TVMArgs ret) { this->ReturnPackedSeq(ret); }); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleInitServer() { + std::string client_protocol_ver; + + uint64_t len; + this->Read(&len); + client_protocol_ver.resize(len); + this->Read(dmlc::BeginPtr(client_protocol_ver), len); + + TVMArgs args = RecvPackedSeq(); + + try { + CHECK(serving_session_ == nullptr) + << "Server has already been initialized"; + + std::string server_protocol_ver = kRPCProtocolVer; + CHECK_EQ(client_protocol_ver, server_protocol_ver) + << "Server[" << name_ << "]: Client protocol version mismatch with the server " + << " server protocol=" << server_protocol_ver + << ", client protocol=" << client_protocol_ver; + + if (args.size() == 0) { + serving_session_ = std::make_shared(); + } else { + std::string constructor_name = args[0]; + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) + << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; + + try { + fconstructor->CallPacked( + TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name + << ":\n" << e.what(); + } + + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name + << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); + } + + this->ReturnVoid(); + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + // Handler for special syscalls that have a specific RPCCode. + template + void SysCallHandler(F f) { + TVMArgs args = RecvPackedSeq(); + try { + TVMRetValue rv; + f(GetServingSession(), args, &rv); + TVMValue ret_value; + int ret_tcode; + TVMArgsSetter setter(&ret_value, &ret_tcode); + setter(0, rv); + + this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1)); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + this->SwitchToState(kRecvPacketNumBytes); + } + + private: + RPCSession* GetServingSession() const { + CHECK(serving_session_ != nullptr) + << "Need to call InitRemoteSession first before any further actions"; + return serving_session_.get(); + } + // Utility functions + // Internal read function, update pending_request_bytes_ + size_t Read(void* data, size_t size) final { + CHECK_LE(size, pending_request_bytes_); + reader_->Read(data, size); + pending_request_bytes_ -= size; + return size; + } + // wriite the data to the channel. + void Write(const void* data, size_t size) final { + writer_->Write(data, size); + } + // Number of pending bytes requests + size_t pending_request_bytes_{0}; + // The ring buffer to read data from. + support::RingBuffer* reader_; + // The ringr buffer to write reply to. + support::RingBuffer* writer_; + // The session used to serve the RPC requests. + std::shared_ptr serving_session_; + // Name of endpoint. + std::string name_; + // remote key + std::string* remote_key_; +}; + +RPCCode RPCEndpoint::HandleUntilReturnEvent( + bool client_mode, RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kCallFunc; + while (code != RPCCode::kReturn && + code != RPCCode::kShutdown && + code != RPCCode::kCopyAck) { + while (writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + size_t bytes_needed = handler_->BytesNeeded(); + if (bytes_needed != 0) { + size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { + return channel_->Recv(data, size); + }, bytes_needed); + if (n == 0) { + if (handler_->CanCleanShutdown()) { + return RPCCode::kShutdown; + } else { + LOG(FATAL) << "Channel closes before we get neded bytes"; + } + } + } + code = handler_->HandleNextEvent(client_mode, setreturn); + } + return code; +} + +void RPCEndpoint::Init() { + // Event handler + handler_ = std::make_shared( + &reader_, &writer_, name_, &remote_key_); + // Quick function to for syscall remote. + syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { + std::lock_guard lock(mutex_); + RPCCode code = static_cast(all_args[0].operator int()); + TVMArgs args(all_args.values + 1, all_args.type_codes +1, all_args.num_args -1); + + uint64_t packet_nbytes = + sizeof(code) + + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + CHECK_EQ(args.size(), 1); + *rv = args[0]; + }); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + }); +} + +std::shared_ptr RPCEndpoint::Create( + std::unique_ptr channel, + std::string name, + std::string remote_key) { + std::shared_ptr endpt = std::make_shared(); + endpt->channel_ = std::move(channel); + endpt->name_ = std::move(name); + endpt->remote_key_ = std::move(remote_key); + endpt->Init(); + return endpt; +} + +RPCEndpoint::~RPCEndpoint() { + this->Shutdown(); +} + +void RPCEndpoint::Shutdown() { + if (channel_ != nullptr) { + RPCCode code = RPCCode::kShutdown; + uint64_t packet_nbytes = sizeof(code); + + handler_->Write(packet_nbytes); + handler_->Write(code); + + // flush all writing buffer to output channel. + try { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + if (n == 0) break; + } + } catch (const dmlc::Error& e) { + } + channel_.reset(nullptr); + } +} + +void RPCEndpoint::ServerLoop() { + if (const auto* f = Registry::Get("tvm.rpc.server.start")) { + (*f)(); + } + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { + (*f)(); + } + channel_.reset(nullptr); +} + +int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { + RPCCode code = RPCCode::kNone; + if (in_bytes.length() != 0) { + reader_.Write(in_bytes.c_str(), in_bytes.length()); + code = handler_->HandleNextEvent(false, [](TVMArgs) {}); + } + if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + if (code == RPCCode::kShutdown) return 0; + if (writer_.bytes_available() != 0) return 2; + return 1; +} + +void RPCEndpoint::InitRemoteSession(TVMArgs args) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kInitServer; + std::string protocol_ver = kRPCProtocolVer; + uint64_t length = protocol_ver.length(); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(length) + + length + + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(length); + handler_->WriteArray(protocol_ver.data(), length); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +// Get remote function with name +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return) { + std::lock_guard lock(mutex_); + + handler_->ValidateArguments(arg_values, arg_type_codes, num_args); + RPCCode code = RPCCode::kCallFunc; + uint64_t handle = reinterpret_cast(h); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + handler_->PackedSeqGetNumBytes( + arg_values, arg_type_codes, num_args, true); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->SendPackedSeq( + arg_values, arg_type_codes, num_args, true); + + code = HandleUntilReturnEvent(true, encode_return); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +void RPCEndpoint::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_to, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyToRemote; + uint64_t handle = reinterpret_cast(to); + uint64_t offset = static_cast(to_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + sizeof(offset) + + sizeof(size) + + sizeof(ctx_to) + + sizeof(type_hint) + + data_size; + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_to); + handler_->Write(type_hint); + handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); + + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn); +} + +void RPCEndpoint::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_from, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyFromRemote; + uint64_t handle = reinterpret_cast(from); + uint64_t offset = static_cast(from_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + sizeof(offset) + + sizeof(size) + + sizeof(ctx_from) + + sizeof(type_hint); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_from); + handler_->Write(type_hint); + + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck); + handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); + handler_->FinishCopyAck(); +} + +// SysCallEventHandler functions +void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + *rv = handler->GetFunction(name); +} + +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* handle = args[0]; + int type_code = args[1]; + handler->FreeHandle(handle, type_code); +} + +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + handler->GetDeviceAPI(ctx)->SetDevice(ctx); +} + +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + DeviceAttrKind kind = static_cast(args[1].operator int()); + if (kind == kExist) { + DeviceAPI* api = handler->GetDeviceAPI(ctx, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, rv); + } else { + *rv = 0; + } + } else { + handler->GetDeviceAPI(ctx)->GetAttr( + ctx, static_cast(kind), rv); + } +} + +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + uint64_t nbytes = args[1]; + uint64_t alignment = args[2]; + DLDataType type_hint = args[3]; + void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace( + ctx, nbytes, alignment, type_hint); + *rv = data; +} + +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + void* ptr = args[1]; + handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); +} + +void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle); +} + +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* from = args[0]; + uint64_t from_offset = args[1]; + void* to = args[2]; + uint64_t to_offset = args[3]; + uint64_t size = args[4]; + TVMContext ctx_from = args[5]; + TVMContext ctx_to = args[6]; + DLDataType type_hint = args[7]; + TVMStreamHandle stream = args[8]; + TVMContext ctx = ctx_from; + + if (ctx.device_type == kDLCPU) { + ctx = ctx_to; + } else { + CHECK(ctx_to.device_type == kDLCPU || + ctx_to.device_type == ctx_from.device_type) + << "Can not copy across different ctx types directly"; + } + handler->GetDeviceAPI(ctx)->CopyDataFromTo( + from, from_offset, + to, to_offset, + size, ctx_from, ctx_to, type_hint, stream); +} + +void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { + // Event handler sit at clean state at this point. + switch (code) { + // system functions + case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break; + case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break; + case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break; + case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break; + case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break; + case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break; + case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; + default: LOG(FATAL) << "Unknown event " << static_cast(code); + } + + CHECK_EQ(state_, kRecvPacketNumBytes); +} + +/*! + * \brief RPC client session that proxies all calls to an endpoint. + */ +class RPCClientSession : public RPCSession, + public DeviceAPI { + public: + /*! + * \brief param endpoint The client endpoint of the session. + */ + explicit RPCClientSession(std::shared_ptr endpoint) + : endpoint_(endpoint) {} + + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final { + return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); + } + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc( + func, arg_values, arg_type_codes, num_args, fencode_return); + } + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final { + endpoint_->CopyToRemote( + from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + } + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final { + endpoint_->CopyFromRemote( + from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + } + + void FreeHandle(void* handle, int type_code) final { + endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); + } + + + void SetDevice(TVMContext ctx) final { + endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); + } + + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (ctx.device_type == kDLCPU && kind == kExist) { + // cpu always exists. + *rv = 1; + } else { + *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast(kind)); + } + } + + void* AllocDataSpace(TVMContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint) final { + return endpoint_->SysCallRemote( + RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr); + } + + void CopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream) final { + endpoint_->SysCallRemote( + RPCCode::kCopyAmongRemote, + const_cast(from), from_offset, + to, to_offset, + size, + ctx_from, ctx_to, + type_hint, stream); + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream); + } + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { + return this; + } + + private: + std::shared_ptr endpoint_; +}; + +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint) { + return std::make_shared(endpoint); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h new file mode 100644 index 000000000000..0fc064baa702 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.h @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_endpoint.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ +#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ + +#include +#include +#include +#include +#include +#include "rpc_session.h" +#include "rpc_channel.h" +#include "rpc_protocol.h" +#include "../../support/ring_buffer.h" + +namespace tvm { +namespace runtime { + +// Magic header for RPC data plane +const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; + +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; + + +/*! + * \brief Communication endpoints to connect local and remote RPC sessions. + * An endpoint can either be a client or a server. + */ +class RPCEndpoint { + public: + /*! \brief virtual destructor */ + ~RPCEndpoint(); + /*! + * \brief The server loop that server runs to handle RPC calls. + */ + void ServerLoop(); + /*! + * \brief Message handling function for an async IO event driven server. + * + * Called when the server receives a message or an IO event update. + * Event driven handler will never call recv on the channel + * and always relies on the ServerIOEventHandler to receive the data. + * + * \param in_bytes The incoming bytes. + * \param event_flag 1: read_available, 2: write_avaiable. + * \return State flag. + * 1: continue running, no need to write, + * 2: need to write + * 0: shutdown + */ + int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag); + + /*! + * \brief Initalize the session on the remote that will be used to back all the RPC requests. + * + * If no session constructor arguments is passed, LocalSession will be used in the remote. + * Otherwise the remote serving session will be constructed using the arguments + * specified in the session_constructor. + * + * The construction rule can be summarized as follows: + * + * \code + * + * auto args = session_constructor; + * int n = args.size(); + * if (n != 0) { + * std::string constructor = args[0]; + * server.serving_session_ = GetGlobalFunc(constructor)( + * args[1], args[2] ... args[n - 1]) + * } else { + * server.serving_session_ = LocalSession(); + * } + * \endcode + * + * \param session_constructor Optional sequence of the remote sesssion constructor. + */ + void InitRemoteSession(TVMArgs session_constructor); + + /*! + * \brief Call into remote function + * \param handle The function handle + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to receive return value encodings. + */ + void CallFunc(RPCSession::PackedFuncHandle handle, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return); + /*! + * \brief Copy bytes into remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_to The target context. + * \param type_hint Hint of content data type. + */ + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint); + /*! + * \brief Copy bytes from remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_from The source context. + * \param type_hint Hint of content data type. + */ + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint); + + /*! + * \brief Call a remote defined system function with arguments. + * \param fcode The function code. + * \param args The arguments + * \return The returned remote value. + */ + template + inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args); + /*! + * \brief Create a RPC session with given channel. + * \param channel The communication channel. + * \param name The local name of the session, used for debug + * \param remote_key The remote key of the session + * if remote_key equals "%toinit", we need to re-intialize + * it by event handler. + */ + static std::shared_ptr Create( + std::unique_ptr channel, + std::string name, + std::string remote_key); + + private: + class EventHandler; + // Handle events until receives a return + // Also flushes channels so that the function advances. + RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); + // Initalization + void Init(); + // Shutdown + void Shutdown(); + // Internal channel. + std::unique_ptr channel_; + // Internal mutex + std::mutex mutex_; + // Internal ring buffer. + support::RingBuffer reader_, writer_; + // Event handler. + std::shared_ptr handler_; + // syscall remote with specified function code. + PackedFunc syscall_remote_; + // The name of the session. + std::string name_; + // The remote key + std::string remote_key_; +}; + +/*! + * \brief Create an RPC client session from an RPC client endpoint. + * \param endpoint The endpoint. + * \return The created session. + */ +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint); + +// implementation of inline functions +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) { + return syscall_remote_(static_cast(code), std::forward(args)...); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 29adb0fed108..284dca5cce6b 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,11 +19,12 @@ /*! * \file rpc_event_impl.cc - * \brief Event based RPC server implementation. + * \brief Event driven RPC server implementation. */ #include #include -#include "rpc_session.h" +#include "rpc_endpoint.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { @@ -35,16 +36,17 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, LOG(FATAL) << "Do not allow explicit receive"; return 0; }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = + RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerEventHandler(args[0], args[1]); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); *rv = ret; }); } -TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer") .set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc new file mode 100644 index 000000000000..0a2809bbaf4c --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file local_session.cc + * \brief Local session that directs requests to local API. + */ +#include +#include +#include +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +RPCSession::PackedFuncHandle +LocalSession::GetFunction(const std::string& name) { + PackedFunc pf = this->GetFunctionInternal(name); + // return raw handl because the remote need to explicitly manage it. + if (pf != nullptr) return new PackedFunc(pf); + return nullptr; +} + +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& encode_return) { + auto* pf = static_cast(func); + TVMRetValue rv; + + pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + int rv_tcode = rv.type_code(); + + // return value encoding. + TVMValue ret_value_pack[3]; + int ret_tcode_pack[3]; + TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack); + // first location always encode type code. + set_arg(0, rv_tcode); + + if (rv_tcode == kTVMNDArrayHandle) { + // We follow a special protocol to return NDArray to client side + // The first pack value is the NDArray handle as DLTensor + // The second pack value is a customized deleter that deletes the NDArray. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMDLTensorHandle; + ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; + ret_tcode_pack[2] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); + } else if (rv_tcode == kTVMPackedFuncHandle || + rv_tcode == kTVMModuleHandle) { + // MoveToCHost means rv no longer manages the object. + // return handle instead. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else if (rv_tcode == kTVMBytes) { + TVMByteArray byte_arr; + auto* sptr = rv.ptr(); + byte_arr.data = sptr->data(); + byte_arr.size = sptr->length(); + set_arg(1, byte_arr); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else { + set_arg(1, rv); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } +} + +void LocalSession::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + this->GetDeviceAPI(ctx_to)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, cpu_ctx, ctx_to, type_hint, nullptr); +} + +void LocalSession::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + + this->GetDeviceAPI(ctx_from)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, ctx_from, cpu_ctx, type_hint, nullptr); +} + +void LocalSession::FreeHandle(void* handle, int type_code) { + TVMValue value; + value.v_handle = handle; + // will trigger deleter once the rv goes out of the scope. + TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code); +} + +DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { + return DeviceAPI::Get(ctx, allow_missing); +} + +PackedFunc LocalSession::GetFunctionInternal(const std::string& name) { + auto* fp = tvm::runtime::Registry::Get(name); + if (fp != nullptr) { + return *fp; + } else { + return nullptr; + } +} + +TVM_REGISTER_GLOBAL("rpc.LocalSession") +.set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h new file mode 100644 index 000000000000..ebb3ea11c50e --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_local_session.h + * \brief Local session that directs all request to the local runtime API. + */ +#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ +#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ + +#include +#include +#include +#include +#include "rpc_session.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief A local session that directly use the handle repr of the + * local tvm runtime objects on the same process. + */ +class LocalSession : public RPCSession { + public: + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final; + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final; + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final; + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final; + + void FreeHandle(void* handle, int type_code) final; + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final; + + protected: + /*! + * \brief Internal implementation of GetFunction. + * \param name The name of the function. + * \return The corresponding PackedFunc. + */ + virtual PackedFunc GetFunctionInternal(const std::string& name); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 0e48e6fb2708..106230457fc5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -18,63 +18,117 @@ */ /*! - * \file rpc_device_api.cc - * \brief RPC module. + * \file rpc_module.cc + * \brief RPC runtime module. */ #include +#include #include #include +#include "rpc_endpoint.h" #include "rpc_session.h" namespace tvm { namespace runtime { -// Wrapped remote function to packed func. -class RPCWrappedFunc { +/*! + * \brief A wrapped remote function as a PackedFunc. + */ +class RPCWrappedFunc : public Object { public: RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) { - fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - WrapRemote(sess, args, rv); - }); } - void operator()(TVMArgs args, TVMRetValue *rv) const { - sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_); + void operator()(TVMArgs args, TVMRetValue* rv) const { + std::vector values(args.values, args.values + args.size()); + std::vector type_codes(args.type_codes, args.type_codes + args.size()); + std::vector> temp_dltensors; + + // scan and check whether we need rewrite these arguments + // to their remote variant. + for (int i = 0; i < args.size(); ++i) { + int tcode = type_codes[i]; + + switch (tcode) { + case kTVMDLTensorHandle: + case kTVMNDArrayHandle: { + // Pass NDArray as DLTensor, NDArray and DLTensor + // are compatible to each other, just need to change the index. + type_codes[i] = kTVMDLTensorHandle; + // translate to a remote view of DLTensor + auto dptr = std::make_unique( + *static_cast(values[i].v_handle)); + dptr->ctx = RemoveSessMask(dptr->ctx); + dptr->data = static_cast(dptr->data)->data; + values[i].v_handle = dptr.get(); + temp_dltensors.emplace_back(std::move(dptr)); + break; + } + case kTVMContext: { + values[i].v_ctx = RemoveSessMask(values[i].v_ctx); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + values[i].v_handle = UnwrapRemoteValueToHandle( + TVMArgValue(values[i], tcode)); + break; + } + } + } + auto set_return = [this, rv](TVMArgs args) { + this->WrapRemoteReturnToValue(args, rv); + }; + sess_->CallFunc(handle_, values.data(), type_codes.data(), + args.size(), set_return); } + ~RPCWrappedFunc() { try { - sess_->CallRemote(RPCCode::kFreeFunc, handle_); + sess_->FreeHandle(handle_, kTVMPackedFuncHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } } - static void WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue* rv); + private: + // remote function handle + void* handle_{nullptr}; + // pointer to the session. + std::shared_ptr sess_; - static void* UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg); + // unwrap a remote value to the underlying handle. + void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; + // wrap a remote return via Set + void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; + + // remove a remote session mask + TVMContext RemoveSessMask(TVMContext ctx) const { + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "Can not pass in local context or context with a different remote session"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } // deleter of RPC remote array static void RemoteNDArrayDeleter(Object* obj) { auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); - space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); + space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle); delete space; delete ptr; } + // wrap return value as remote NDArray. - static NDArray WrapRemoteNDArray(std::shared_ptr sess, - DLTensor* tensor, - void* nd_handle) { + NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); - space->sess = sess; + space->sess = sess_; space->data = tensor->data; data->dl_tensor.data = space; NDArray ret(GetObjectPtr(data)); @@ -89,18 +143,13 @@ class RPCWrappedFunc { data->dl_tensor.ctx.device_id = tensor->ctx.device_id; data->dl_tensor.ctx.device_type = static_cast( static_cast(tensor->ctx.device_type) + - kRPCSessMask * (sess->table_index() + 1)); + kRPCSessMask * (sess_->table_index() + 1)); // check strides. CHECK(tensor->strides == nullptr); // setup byteoffset data->dl_tensor.byte_offset = tensor->byte_offset; return ret; } - - private: - PackedFunc fwrap_; - void* handle_{nullptr}; - std::shared_ptr sess_; }; // RPC that represents a remote module session. @@ -109,10 +158,11 @@ class RPCModuleNode final : public ModuleNode { RPCModuleNode(void* module_handle, std::shared_ptr sess) : module_handle_(module_handle), sess_(sess) { } + ~RPCModuleNode() { if (module_handle_ != nullptr) { try { - sess_->CallRemote(RPCCode::kModuleFree, module_handle_); + sess_->FreeHandle(module_handle_, kTVMModuleHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } @@ -127,31 +177,56 @@ class RPCModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, const ObjectPtr& sptr_to_self) final { - RPCFuncHandle handle = GetFuncHandle(name); - return WrapRemote(handle); + if (module_handle_ == nullptr) { + return WrapRemoteFunc(sess_->GetFunction(name)); + } else { + InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); + return remote_mod_get_function_(GetRef(this), name, false); + } } std::string GetSource(const std::string& format) final { - if (module_handle_ != nullptr) { - std::string ret = sess_->CallRemote( - RPCCode::kModuleGetSource, module_handle_, format); - } + LOG(FATAL) << "GetSource for rpc Module is not supported"; return ""; } - std::shared_ptr& sess() { - return sess_; - } - PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - RPCFuncHandle handle = GetFuncHandle(name); - if (handle == nullptr) return PackedFunc(); - handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms); - return WrapRemote(handle); + InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); + // Remove session mask because we pass ctx by parts. + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + + if (module_handle_ != nullptr) { + return remote_get_time_evaluator_( + GetRef(this), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } else { + return remote_get_time_evaluator_( + Optional(nullptr), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } + } + + Module LoadModule(std::string name) { + InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); + return remote_load_module_(name); + } + + void ImportModule(Module other) { + InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); + remote_import_module_(GetRef(this), other); + } + + const std::shared_ptr& sess() { + return sess_; } void* module_handle() const { @@ -159,7 +234,15 @@ class RPCModuleNode final : public ModuleNode { } private: - PackedFunc WrapRemote(RPCFuncHandle handle) { + template + void InitRemoteFunc(FType* func, const std::string& name) { + if (*func != nullptr) return; + RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); + CHECK(handle != nullptr) << "Cannot found remote function " << name; + *func = WrapRemoteFunc(handle); + } + + PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared(handle, sess_); return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { @@ -167,33 +250,30 @@ class RPCModuleNode final : public ModuleNode { }); } - RPCFuncHandle GetFuncHandle(const std::string& name) { - RPCFuncHandle handle = nullptr; - if (module_handle_ == nullptr) { - handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); - } else { - handle = sess_->CallRemote( - RPCCode::kModuleGetFunc, module_handle_, name); - } - return handle; - } // The module handle void* module_handle_{nullptr}; // The local channel std::shared_ptr sess_; - // Wrap function to wrap remote module/function. - PackedFunc fwrap_; + // remote function to get time evaluator + TypedPackedFunc, std::string, int, int, int, int, int)> + remote_get_time_evaluator_; + // remote function getter for modules. + TypedPackedFunc remote_mod_get_function_; + // remote function getter for load module + TypedPackedFunc remote_load_module_; + // remote function getter for load module + TypedPackedFunc remote_import_module_; }; -void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg) { + +void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { if (arg.type_code() == kTVMModuleHandle) { Module mod = arg; std::string tkey = mod->type_key(); CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); - CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index) + CHECK(rmod->sess() == sess_) << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { @@ -204,93 +284,173 @@ void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, } } -void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue *rv) { - void* handle = args.values[0].v_handle; - int tcode = args.type_codes[0]; +void RPCWrappedFunc::WrapRemoteReturnToValue( + TVMArgs args, + TVMRetValue *rv) const { + int tcode = args[0]; - if (handle == nullptr) return; + if (tcode == kTVMNullptr) return; if (tcode == kTVMPackedFuncHandle) { - auto wf = std::make_shared(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto wf = std::make_shared(handle, sess_); *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + return wf->operator()(args, rv); + }); } else if (tcode == kTVMModuleHandle) { - auto n = make_object(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); *rv = Module(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { - CHECK_EQ(args.size(), 2); - DLTensor* tensor = args[0]; - void* nd_handle = args[1]; - *rv = WrapRemoteNDArray(sess, tensor, nd_handle); + CHECK_EQ(args.size(), 3); + DLTensor* tensor = args[1]; + void* nd_handle = args[2]; + *rv = WrapRemoteNDArray(tensor, nd_handle); } else { - LOG(FATAL) << "Cannot wrap tcode=" << tcode; + CHECK_EQ(args.size(), 2); + *rv = args[1]; } } -Module CreateRPCModule(std::shared_ptr sess) { +Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); + RPCSession::InsertToSessionTable(sess); return Module(n); } +std::shared_ptr RPCModuleGetSession(Module mod) { + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "ValueError: Cannot pass a non-RPC module to remote"; + auto* rmod = static_cast(mod.operator->()); + return rmod->sess(); +} + +PackedFunc WrapTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + CHECK(pf != nullptr); + + if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { + auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); + CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; + return (*get_micro_time_evaluator)(pf, ctx, number, repeat); + } + + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) + mutable { + TVMRetValue temp; + std::ostringstream os; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + + for (int i = 0; i < repeat; ++i) { + std::chrono::time_point< + std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + + tbegin = std::chrono::high_resolution_clock::now(); + // start timing + for (int i = 0; i < number; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + tend = std::chrono::high_resolution_clock::now(); + + duration_ms = std::chrono::duration_cast > + (tend - tbegin).count() * 1000; + } while (duration_ms < min_repeat_ms); + + double speed = std::chrono::duration_cast >( + tend - tbegin).count() / number; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + + TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); std::string tkey = m->type_key(); - TVMContext ctx; - ctx.device_type = static_cast(args[2].operator int()); - ctx.device_id = args[3]; if (tkey == "rpc") { - *rv = static_cast(m.operator->()) - ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]); + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); } else { - *rv = WrapTimeEvaluator( - m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]); + return WrapTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); } - }); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); -TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - auto& sess = static_cast(m.operator->())->sess(); - void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); - auto n = make_object(mhandle, sess); - *rv = Module(n); - }); +// server function registration. +TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") +.set_body_typed([](Module parent, Module child) { + parent->Import(child); +}); -TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module parent = args[0]; - Module child = args[1]; - CHECK(!std::strcmp(parent->type_key(), "rpc") && - !std::strcmp(child->type_key(), "rpc")); - auto* pmod = static_cast(parent.operator->()); - auto* cmod = static_cast(child.operator->()); - CHECK(pmod->sess().get() == cmod->sess().get()) - << "Import of remote module need to belong to same session."; - pmod->sess()->CallRemote(RPCCode::kModuleImport, - pmod->module_handle(), - cmod->module_handle()); - }); - -TVM_REGISTER_GLOBAL("rpc._ModuleHandle") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->module_handle(); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") +.set_body_typed([](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); +}); + +// functions to access an RPC module. +TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule") +.set_body_typed([](Module sess, std::string name) { + std::string tkey = sess->type_key(); + CHECK_EQ(tkey, "rpc"); + return static_cast(sess.operator->())->LoadModule(name); +}); -TVM_REGISTER_GLOBAL("rpc._SessTableIndex") +TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule") +.set_body_typed([](Module parent, Module child) { + std::string tkey = parent->type_key(); + CHECK_EQ(tkey, "rpc"); + static_cast(parent.operator->())->ImportModule(child); +}); + +TVM_REGISTER_GLOBAL("rpc.SessTableIndex") .set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); - }); + Module m = args[0]; + std::string tkey = m->type_key(); + CHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc new file mode 100644 index 000000000000..376b8b5dc61e --- /dev/null +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_pipe_impl.cc + * \brief Pipe-based RPC channel. + */ +// Linux only for now, as linux is the most common usecase. +#if defined(__linux__) || defined(__ANDROID__) + +#include +#include +#include +#include + +#include +#include +#include + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "../../support/pipe.h" + +namespace tvm { +namespace runtime { + +class PipeChannel final : public RPCChannel { + public: + explicit PipeChannel(int readfd, int writefd, pid_t child_pid) + : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) { + } + + ~PipeChannel() { + Close(); + } + + size_t Send(const void* data, size_t size) final { + ssize_t n = write(writefd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe write error"; + } + return static_cast(n); + } + + size_t Recv(void* data, size_t size) final { + ssize_t n = read(readfd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe read error"; + } + return static_cast(n); + } + + void Close() { + close(readfd_); + close(writefd_); + kill(child_pid_, SIGKILL); + } + + private: + int readfd_; + int writefd_; + pid_t child_pid_; +}; + + +Module CreatePipeClient(std::vector cmd) { + int parent2child[2]; + int child2parent[2]; + CHECK_EQ(pipe(parent2child), 0); + CHECK_EQ(pipe(child2parent), 0); + + int parent_read = child2parent[0]; + int parent_write = parent2child[1]; + int child_read = parent2child[0]; + int child_write = child2parent[1]; + + pid_t pid = fork(); + if (pid == 0) { + // child process + close(parent_read); + close(parent_write); + std::string sread_pipe = std::to_string(child_read); + std::string swrite_pipe = std::to_string(child_write); + std::vector argv; + for (auto& str : cmd) { + argv.push_back(dmlc::BeginPtr(str)); + } + argv.push_back(dmlc::BeginPtr(sread_pipe)); + argv.push_back(dmlc::BeginPtr(swrite_pipe)); + argv.push_back(nullptr); + execvp(argv[0], &argv[0]); + } + // parent process + close(child_read); + close(child_write); + + auto endpt = RPCEndpoint::Create( + std::unique_ptr( + new PipeChannel(parent_read, parent_write, pid)), + "pipe", "pipe"); + endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0)); + return CreateRPCSessionModule(CreateClientSession(endpt)); +} + +TVM_REGISTER_GLOBAL("rpc.CreatePipeClient") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].operator std::string()); + } + *rv = CreatePipeClient(cmd); +}); + + +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h new file mode 100644 index 000000000000..6221bfbe1e82 --- /dev/null +++ b/src/runtime/rpc/rpc_protocol.h @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_procotol.h + * \brief Common header defining the communication code used in the RPC protocol. + */ +#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ +#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ + +namespace tvm { +namespace runtime { + +/*! \brief The current RPC procotol version. */ +constexpr const char* kRPCProtocolVer = "0.7.0"; + +/*! \brief The RPC code */ +enum class RPCCode : int { + kNone, + kShutdown, + kInitServer, + kCallFunc, + kReturn, + kException, + kCopyFromRemote, + kCopyToRemote, + kCopyAck, + // The following are syscall code that can send over CallRemote + kSyscallCodeStart, + kGetGlobalFunc = kSyscallCodeStart, + kFreeHandle, + kDevSetDevice, + kDevGetAttr, + kDevAllocData, + kDevFreeData, + kDevStreamSync, + kCopyAmongRemote, +}; + +/*! + * \brief List of potential error status during rpc communication. + */ +enum class RPCServerStatus : int { + kSuccess = 0, + kInvalidTypeCodeObject, + kInvalidTypeCodeNDArray, + kInvalidDLTensorFieldStride, + kInvalidDLTensorFieldByteOffset, + kUnknownTypeCode, + kUnknownRPCCode, + kRPCCodeNotSupported, + kUnknownRPCSyscall, + kCheckError, + kReadError, + kWriteError, + kAllocError +}; + +/*! + * \brief Convert RPC server status to string. + * \param status The status. + * \return The corresponding string. + */ +inline const char* RPCServerStatusToString(RPCServerStatus status) { + switch (status) { + case RPCServerStatus::kSuccess: return "kSuccess"; + case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject"; + case RPCServerStatus::kInvalidTypeCodeNDArray: return "kInvalidTypeCodeNDArray"; + case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride"; + case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { + return "kInvalidDLTensorFieldByteOffset"; + } + case RPCServerStatus::kUnknownTypeCode: return "kUnknownTypeCode"; + case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode"; + case RPCServerStatus::kRPCCodeNotSupported: return "RPCCodeNotSupported"; + case RPCServerStatus::kUnknownRPCSyscall: return "kUnknownRPCSyscall"; + case RPCServerStatus::kCheckError: return "kCheckError"; + case RPCServerStatus::kReadError: return "kReadError"; + case RPCServerStatus::kWriteError: return "kWriteError"; + case RPCServerStatus::kAllocError: return "kAllocError"; + default: return ""; + } +} + +/*! + * \brief Reference implementation of the communication protocol. + * + * \note The implementation is intentionally written via template + * so it can be used in a dependency free setting. + * + * \sa src/runtime/rpc/device/min_rpc_server.h + */ +struct RPCReference { + /*! + * \brief Auxiliary class to get the packed sequence. + * \tparam TChannel The channel to throw errror. + */ + template + struct PackedSeqNumBytesGetter { + public: + explicit PackedSeqNumBytesGetter(TChannel* channel) + : channel_(channel) {} + + template + void Write(const T& value) { + num_bytes_ += sizeof(T); + } + + template + void WriteArray(const T* value, size_t num) { + num_bytes_ += sizeof(T) * num; + } + + void ThrowError(RPCServerStatus status) { + channel_->ThrowError(status); + } + + uint64_t num_bytes() const { + return num_bytes_; + } + + private: + TChannel* channel_; + uint64_t num_bytes_{0}; + }; + + /*! + * \return the length of the str. + * \param str the string. + * \return The length. + */ + static uint64_t StrLength(const char* str) { + uint64_t len = 0; + while (str[len] != '\0') ++len; + return len; + } + + /*! + * \brief Get the total nbytes to be sent in the packed sequence. + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \return The total number of bytes. + */ + template + static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode, + TChannel* channel) { + PackedSeqNumBytesGetter getter(channel); + SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); + return getter.num_bytes(); + } + + /*! + * \brief Send packed argument sequnce to the other peer. + * + * This function serves as the foundational communication primitive between peers. + * + * TVMValue sequence encoding protocol(according to the type): + * + * - int/float/uint/bytes/str: Serialize all content. + * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t) + * - OpaqueHandle: send as uint64_t + * - ModuleHandle, PackedFuncHandle: send as uint64_t, + * The support to Module/PackedFuncHandle are reserved for arguments + * in the CallFunc from a client to server only. + * Note that we cannot simply take these argument out(as the handle) + * refers to a value on the remote(instead of local). + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void SendPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode, + TChannel* channel) { + channel->Write(num_args); + channel->WriteArray(type_codes, num_args); + + // Argument packing. + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + TVMValue value = arg_values[i]; + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Write(value.v_int64); + break; + } + case kTVMDataType: { + channel->Write(value.v_type); + // padding + int32_t padding = 0; + channel->template Write(padding); + break; + } + case kTVMContext: { + channel->Write(value.v_ctx); + break; + } + + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + if (!client_mode) { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject); + } + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMOpaqueHandle: { + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMNDArrayHandle: { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); + break; + } + case kTVMDLTensorHandle: { + DLTensor* arr = static_cast(value.v_handle); + TVMContext ctx; + uint64_t data; + // When we return NDArray, we directly return + // the space and the context + // The client will be further wrapping + ctx = arr->ctx; + data = reinterpret_cast(arr->data); + channel->Write(data); + channel->Write(ctx); + channel->Write(arr->ndim); + channel->Write(arr->dtype); + channel->WriteArray(arr->shape, arr->ndim); + if (arr->strides != nullptr) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); + } + if (arr->byte_offset != 0) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldByteOffset); + } + break; + } + case kTVMNullptr: break; + case kTVMStr: { + const char* s = value.v_str; + uint64_t len = StrLength(s); + channel->Write(len); + channel->WriteArray(s, len); + break; + } + case kTVMBytes: { + TVMByteArray* bytes = static_cast(arg_values[i].v_handle); + uint64_t len = bytes->size; + channel->Write(len); + channel->WriteArray(bytes->data, len); + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Receive packed seq from the channel. + * + * \param out_arg_values The values to be received. + * \param out_tcodes The type codes to be received. + * \param out_num_args Number of argument. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \note The temporary space are populated via an arena inside channel. + */ + template + static void RecvPackedSeq(TVMValue** out_values, + int** out_tcodes, + int* out_num_args, + TChannel* channel) { + // receive number of args + int num_args; + channel->Read(&num_args); + *out_num_args = num_args; + + if (num_args == 0) { + *out_values = nullptr; + *out_tcodes = nullptr; + return; + } + + TVMValue* values = channel->template ArenaAlloc(num_args); + int* tcodes = channel->template ArenaAlloc(num_args); + *out_values = values; + *out_tcodes = tcodes; + + // receive type code. + channel->ReadArray(tcodes, num_args); + + // receive arguments + for (int i = 0; i < num_args; ++i) { + auto& value = values[i]; + switch (tcodes[i]) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Read(&(value.v_int64)); + break; + } + case kTVMDataType: { + channel->Read(&(value.v_type)); + int32_t padding = 0; + channel->template Read(&padding); + break; + } + case kTVMContext: { + channel->Read(&(value.v_ctx)); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: + case kTVMOpaqueHandle: { + // always send handle in 64 bit. + uint64_t handle; + channel->Read(&handle); + value.v_handle = reinterpret_cast(handle); + break; + } + case kTVMNullptr: { + value.v_handle = nullptr; + break; + } + case kTVMStr: { + uint64_t len; + channel->Read(&len); + char* str = channel->template ArenaAlloc(len + 1); + str[len] = '\0'; + channel->ReadArray(str, len); + value.v_str = str; + break; + } + case kTVMBytes: { + uint64_t len; + channel->Read(&len); + TVMByteArray* arr = channel->template ArenaAlloc(1); + char* data = channel->template ArenaAlloc(len); + arr->size = len; + arr->data = data; + channel->ReadArray(data, len); + value.v_handle = arr; + break; + } + case kTVMDLTensorHandle: { + uint64_t handle; + channel->Read(&handle); + DLTensor* arr = channel->template ArenaAlloc(1); + DLTensor& tensor = *arr; + tensor.data = reinterpret_cast(handle); + channel->Read(&(tensor.ctx)); + channel->Read(&(tensor.ndim)); + channel->Read(&(tensor.dtype)); + tensor.shape = channel->template ArenaAlloc(tensor.ndim); + channel->ReadArray(tensor.shape, tensor.ndim); + tensor.strides = nullptr; + tensor.byte_offset = 0; + value.v_handle = arr; + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Return an exception packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnException(const char* msg, TChannel* channel) { + RPCCode code = RPCCode::kException; + int32_t num_args = 1; + int32_t tcode = kTVMStr; + uint64_t len = StrLength(msg); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(num_args) + + sizeof(tcode) + + sizeof(len) + + len; + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + channel->Write(len); + channel->WriteArray(msg, len); + } + + /*! + * \brief Return a normal packed sequence packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + TChannel* channel) { + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + + PackedSeqGetNumBytes( + arg_values, type_codes, num_args, false, channel); + + channel->Write(packet_nbytes); + channel->Write(code); + SendPackedSeq( + arg_values, type_codes, num_args, false, channel); + } + + /*! + * \brief Return a null(void) packet. + * + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnVoid(TChannel* channel) { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(num_args) + + sizeof(tcode); + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + } +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index f6a7fb60b5f4..612ca418e812 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,8 @@ namespace tvm { namespace runtime { std::string RPCGetPath(const std::string& name) { - static const PackedFunc* f = + // do live lookup everytime as workpath can change. + const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath"); CHECK(f != nullptr) << "require tvm.rpc.server.workpath"; return (*f)(name); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index ae293abfacdd..dd0afa0145d2 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -21,817 +21,16 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ -#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include -#include #include "rpc_session.h" -#include "../object_internal.h" -#include "../../support/ring_buffer.h" -#include "../../support/socket.h" -#include "../micro/micro_session.h" namespace tvm { namespace runtime { -// Temp buffer for data array -struct RPCByteArrayBuffer { - TVMByteArray arr; - std::string data; -}; -// Temp buffer for data array -struct RPCDataArrayBuffer { - DLTensor tensor; - std::vector shape; -}; -/*! - * \brief Temporal argument buffer. - */ -struct RPCArgBuffer { - // The argument values - std::vector value; - // The type codes. - std::vector tcode; - // Temporal resources. - std::vector > temp_bytes; - // Temporal array - std::vector > temp_array; - // convert buffer as TVMArgs - TVMArgs AsTVMArgs() const { - return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); - } -}; - -// Event handler for RPC events. -class RPCSession::EventHandler : public dmlc::Stream { - public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - int rpc_sess_table_index, - std::string name, - std::string* remote_key) - : reader_(reader), - writer_(writer), - rpc_sess_table_index_(rpc_sess_table_index), - name_(name), - remote_key_(remote_key) { - this->Clear(); - if (*remote_key == "%toinit") { - state_ = kInitHeader; - remote_key_->resize(0); - pending_request_bytes_ = sizeof(int32_t); - } - } - // Bytes needed to fulfill current request - size_t BytesNeeded() { - if (reader_->bytes_available() < pending_request_bytes_) { - return pending_request_bytes_ - reader_->bytes_available(); - } else { - return 0; - } - } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. - bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; - } - bool CanCleanShutdown() const { - return state_ == kRecvCode; - } - void FinishCopyAck() { - this->SwitchToState(kRecvCode); - } - RPCCode HandleNextEvent(TVMRetValue* rv, - bool client_mode, - const PackedFunc* fwrap) { - std::swap(client_mode_, client_mode); - while (this->Ready()) { - switch (state_) { - case kInitHeader: HandleInitHeader(); break; - case kRecvCode: HandleRecvCode(); break; - case kRecvCallHandle: { - CHECK(this->Read(&call_handle_)); - this->SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case kRecvPackedSeqNumArgs: { - CHECK(this->Read(&num_packed_args_)); - arg_buf_.reset(new RPCArgBuffer()); - arg_buf_->value.resize(num_packed_args_); - arg_buf_->tcode.resize(num_packed_args_); - this->SwitchToState(kRecvPackedSeqTypeCode); - break; - } - case kRecvPackedSeqTypeCode: { - if (num_packed_args_ != 0) { - this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); - } - arg_index_ = 0; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kRecvPackedSeqArg: { - this->HandleRecvPackedSeqArg(); - break; - } - case kDoCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case kDoCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case kReturnReceived: { - CHECK_GE(arg_buf_->value.size(), 1U); - - TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; - if (argv.type_code() == kTVMPackedFuncHandle || - argv.type_code() == kTVMModuleHandle || - argv.type_code() == kTVMDLTensorHandle) { - CHECK(fwrap != nullptr) << "function/module wrapper not available"; - fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); - } else { - CHECK_EQ(arg_buf_->value.size(), 1U); - *rv = argv; - } - arg_buf_.reset(); - this->SwitchToState(kRecvCode); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; - } - case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; - } - case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; - } - } - } - std::swap(client_mode_, client_mode); - return RPCCode::kNone; - } - // Reset and clear all states. - void Clear() { - state_ = kRecvCode; - pending_request_bytes_ = sizeof(RPCCode); - arg_recv_stage_ = 0; - arg_buf_.reset(); - } - // strip session on mask - TVMContext StripSessMask(TVMContext ctx) { - int dev_type = ctx.device_type; - CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1) - << "Can not pass in local context or context with a different remote session"; - ctx.device_type = static_cast(dev_type % kRPCSessMask); - return ctx; - } - // Send Packed sequence to writer. - // - // client_mode: whether we are in client mode. - // - // funwrap: auxiliary function to unwrap remote Object - // when it is provided, we need to unwrap objects. - // - // return_ndarray is a special flag to handle returning of ndarray - // In this case, we return the shape, context and data of the array, - // as well as a customized PackedFunc that handles deletion of - // the array in the remote. - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - FUnwrapRemoteObject funwrap = nullptr, - bool return_ndarray = false) { - std::swap(client_mode_, client_mode); - - this->Write(num_args); - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle; - this->Write(tcode); - } - - // Argument packing. - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Write(value.v_int64); - break; - } - case kTVMDataType: { - this->Write(value.v_type); - // padding - int32_t padding = 0; - this->Write(padding); - break; - } - case kTVMContext: { - value.v_ctx = StripSessMask(value.v_ctx); - this->Write(value.v_ctx); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { - // always send handle in 64 bit. - uint64_t handle; - // allow pass module as argument to remote. - if (funwrap != nullptr) { - void* remote_handle = (*funwrap)( - rpc_sess_table_index_, - runtime::TVMArgValue(value, tcode)); - handle = reinterpret_cast(remote_handle); - } else { - CHECK(!client_mode_) - << "Cannot directly pass remote object as argument"; - handle = reinterpret_cast(value.v_handle); - } - this->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - this->Write(handle); - break; - } - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); - TVMContext ctx; - uint64_t data; - if (!return_ndarray) { - // in the client mode - // ctx contains the remote table index - // the space is wrapped by an RemoteSpace - // that holds reference to the session. - ctx = StripSessMask(arr->ctx); - data = reinterpret_cast( - static_cast(arr->data)->data); - } else { - // When we return NDArray, we directly return - // the space and the context - // The client will be further wrapping - ctx = arr->ctx; - data = reinterpret_cast(arr->data); - } - this->Write(data); - this->Write(ctx); - this->Write(arr->ndim); - this->Write(arr->dtype); - this->WriteArray(arr->shape, arr->ndim); - CHECK(arr->strides == nullptr) - << "Do not support strided remote array"; - CHECK_EQ(arr->byte_offset, 0) - << "Do not support send byte offset"; - break; - } - case kTVMNullptr: break; - case kTVMStr: { - const char* s = value.v_str; - uint64_t len = strlen(s); - this->Write(len); - this->WriteArray(s, len); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); - uint64_t len = bytes->size; - this->Write(len); - this->WriteArray(bytes->data, len); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - std::swap(client_mode_, client_mode); - } - - // Endian aware IO handling - using Stream::Read; - using Stream::Write; - using Stream::ReadArray; - using Stream::WriteArray; - - inline bool Read(RPCCode* code) { - int cdata; - if (!this->Read(&cdata)) return false; - *code = static_cast(cdata); - return true; - } - inline void Write(RPCCode code) { - int cdata = static_cast(code); - this->Write(cdata); - } - - protected: - enum State { - kInitHeader, - kRecvCode, - kRecvCallHandle, - kRecvPackedSeqNumArgs, - kRecvPackedSeqTypeCode, - kRecvPackedSeqArg, - kDoCopyFromRemote, - kDoCopyToRemote, - kReturnReceived, - kCopyAckReceived, - kShutdownReceived - }; - // Current state; - State state_; - // The RPCCode to be read. - RPCCode code_; - // Handle for the remote function call. - uint64_t call_handle_; - // Initialize remote header - bool init_header_step_{0}; - // Number of packed arguments. - int num_packed_args_; - // Current argument index. - int arg_index_; - // The stage of each argument receiver. - int arg_recv_stage_; - // Whether current handler is client or server mode. - bool client_mode_{false}; - // Argument buffer - std::unique_ptr arg_buf_; - // Temp byte buffer. - std::unique_ptr temp_bytes_; - // Temp array buffer. - std::unique_ptr temp_array_; - // Internal temporal data space. - std::string temp_data_; - // Temp variables for copy request state. - TVMContext copy_ctx_; - DLDataType copy_dtype_; - uint64_t copy_handle_, copy_offset_, copy_size_; - // State switcher - void SwitchToState(State state) { - // invariant - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; - state_ = state; - switch (state) { - case kInitHeader: { - LOG(FATAL) << "cannot switch to init header"; - break; - } - case kRecvCode: { - this->RequestBytes(sizeof(RPCCode)); - break; - } - case kRecvCallHandle: { - this->RequestBytes(sizeof(call_handle_)); - break; - } - case kRecvPackedSeqNumArgs: { - this->RequestBytes(sizeof(num_packed_args_)); - break; - } - case kRecvPackedSeqTypeCode: { - this->RequestBytes(sizeof(int) * num_packed_args_); - break; - } - case kRecvPackedSeqArg: { - CHECK_LE(arg_index_, num_packed_args_); - if (arg_index_ == num_packed_args_) { - // The function can change state_ again. - HandlePackedCall(); - } else { - RequestRecvPackedSeqArg(); - } - break; - } - case kDoCopyFromRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kDoCopyToRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kCopyAckReceived: - case kReturnReceived: - case kShutdownReceived: { - break; - } - } - } - // Requets bytes needed for next computation. - void RequestRecvPackedSeqArg() { - CHECK_EQ(arg_recv_stage_, 0); - int tcode = arg_buf_->tcode[arg_index_]; - static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: - case kTVMDataType: - case kTVMOpaqueHandle: - case kTVMStr: - case kTVMBytes: - case kTVMModuleHandle: - case kTVMContext: { - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMPackedFuncHandle: { - CHECK(client_mode_) - << "Only client can receive remote functions"; - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMNullptr: break; - case kTVMDLTensorHandle: { - this->RequestBytes(sizeof(uint64_t)); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(int)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - // Handler for packed sequence argument receive. - void HandleRecvPackedSeqArg() { - CHECK_LT(arg_index_, num_packed_args_); - int tcode = arg_buf_->tcode[arg_index_]; - TVMValue& value = arg_buf_->value[arg_index_]; - if (arg_recv_stage_ == 0) { - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Read(&(value.v_int64)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMDataType: { - this->Read(&(value.v_type)); - int32_t padding = 0; - this->Read(&padding); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMContext: { - this->Read(&(value.v_ctx)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle; - this->Read(&handle); - value.v_handle = reinterpret_cast(handle); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMStr: - case kTVMBytes: { - uint64_t len; - this->Read(&len); - temp_bytes_.reset( new RPCByteArrayBuffer()); - temp_bytes_->data.resize(len); - arg_recv_stage_ = 1; - this->RequestBytes(len); - break; - } - case kTVMDLTensorHandle: { - temp_array_.reset(new RPCDataArrayBuffer()); - uint64_t handle; - this->Read(&handle); - DLTensor& tensor = temp_array_->tensor; - tensor.data = reinterpret_cast(handle); - this->Read(&(tensor.ctx)); - this->Read(&(tensor.ndim)); - this->Read(&(tensor.dtype)); - temp_array_->shape.resize(tensor.ndim); - tensor.shape = temp_array_->shape.data(); - arg_recv_stage_ = 1; - tensor.strides = nullptr; - tensor.byte_offset = 0; - this->RequestBytes(sizeof(int64_t) * tensor.ndim); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } else { - CHECK_EQ(arg_recv_stage_, 1); - if (tcode == kTVMStr || tcode == kTVMBytes) { - if (temp_bytes_->data.size() != 0) { - this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); - } - if (tcode == kTVMStr) { - value.v_str = temp_bytes_->data.c_str(); - } else { - temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); - temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); - value.v_handle = &(temp_bytes_->arr); - } - arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); - } else { - CHECK_EQ(tcode, kTVMDLTensorHandle); - DLTensor& tensor = temp_array_->tensor; - this->ReadArray(tensor.shape, tensor.ndim); - value.v_handle = &tensor; - arg_buf_->temp_array.emplace_back(std::move(temp_array_)); - } - ++arg_index_; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - } - } - // handler for initial header read - void HandleInitHeader() { - if (init_header_step_ == 0) { - int32_t len; - this->Read(&len); - remote_key_->resize(len); - init_header_step_ = 1; - this->RequestBytes(len); - return; - } else { - CHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); - this->SwitchToState(kRecvCode); - } - } - // Handler for read code. - void HandleRecvCode() { - this->Read(&code_); - if (code_ > RPCCode::kSystemFuncStart) { - SwitchToState(kRecvPackedSeqNumArgs); - return; - } - // invariant. - CHECK_EQ(arg_recv_stage_, 0); - switch (code_) { - case RPCCode::kCallFunc: { - SwitchToState(kRecvCallHandle); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case RPCCode::kCopyFromRemote: { - SwitchToState(kDoCopyFromRemote); - break; - } - case RPCCode::kCopyToRemote: { - SwitchToState(kDoCopyToRemote); - break; - } - case RPCCode::kShutdown: { - SwitchToState(kShutdownReceived); - break; - } - case RPCCode::kCopyAck: { - SwitchToState(kCopyAckReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - } - - void HandleCopyFromRemote() { - uint64_t handle, offset, num_bytes; - TVMContext ctx; - DLDataType type_hint; - this->Read(&handle); - this->Read(&offset); - this->Read(&num_bytes); - this->Read(&ctx); - this->Read(&type_hint); - size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; - - if (ctx.device_type == kDLCPU) { - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - char* dptr = reinterpret_cast(handle) + offset; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - temp_data_.resize(0); - temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - this->WriteArray(temp_data_.data(), num_bytes); - } else { - this->WriteArray(dptr, num_bytes); - } - } else { - temp_data_.resize(num_bytes + 1); - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(ctx)->CopyDataFromTo( - reinterpret_cast(handle), offset, - dmlc::BeginPtr(temp_data_), 0, - num_bytes, ctx, cpu_ctx, type_hint, nullptr); - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - } - this->WriteArray(&temp_data_[0], num_bytes); - } catch (const std::runtime_error &e) { - RPCCode code = RPCCode::kException; - this->Write(code); - TVMValue ret_value; - ret_value.v_str = e.what(); - int ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - this->SwitchToState(kRecvCode); - } - - void HandleCopyToRemote() { - // use static variable to persist state. - // This only works if next stage is immediately after this. - if (arg_recv_stage_ == 0) { - CHECK(this->Read(©_handle_)); - CHECK(this->Read(©_offset_)); - CHECK(this->Read(©_size_)); - CHECK(this->Read(©_ctx_)); - CHECK(this->Read(©_dtype_)); - arg_recv_stage_ = 1; - CHECK_EQ(pending_request_bytes_, 0U); - this->RequestBytes(copy_size_); - } else { - CHECK_EQ(arg_recv_stage_, 1); - TVMValue ret_value; - ret_value.v_handle = nullptr; - int ret_tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - std::string errmsg; - - size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; - if (copy_ctx_.device_type == kDLCPU) { - char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; - this->ReadArray(dptr, copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); - } - } else { - temp_data_.resize(copy_size_ + 1); - this->ReadArray(&temp_data_[0], copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); - } - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( - temp_data_.data(), 0, - reinterpret_cast(copy_handle_), copy_offset_, - copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr); - } catch (const std::runtime_error &e) { - code = RPCCode::kException; - errmsg = e.what(); - ret_value.v_str = errmsg.c_str(); - ret_tcode = kTVMStr; - } - } - this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - arg_recv_stage_ = 0; - this->SwitchToState(kRecvCode); - } - } - // Handle for packed call. - void HandlePackedCall(); - - template - void CallHandler(F f) { - TVMRetValue rv; - TVMValue ret_value; - int ret_tcode; - try { - // Need to move out, in case f itself need to call RecvPackedSeq - // Which will override argbuf again. - std::unique_ptr args = std::move(arg_buf_); - f(args->AsTVMArgs(), &rv); - RPCCode code = RPCCode::kReturn; - this->Write(code); - if (rv.type_code() == kTVMStr) { - ret_value.v_str = rv.ptr()->c_str(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMBytes) { - std::string* bytes = rv.ptr(); - TVMByteArray arr; - arr.data = bytes->c_str(); - arr.size = bytes->length(); - ret_value.v_handle = &arr; - ret_tcode = kTVMBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMPackedFuncHandle || - rv.type_code() == kTVMModuleHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send function and module handle back."; - rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMNDArrayHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send NDArray back"; - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. - TVMValue ret_value_pack[2]; - int ret_tcode_pack[2]; - rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; - ret_tcode_pack[1] = kTVMOpaqueHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); - } else { - ret_value = rv.value(); - ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } catch (const std::runtime_error& e) { - RPCCode code = RPCCode::kException; - this->Write(code); - ret_value.v_str = e.what(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - - private: - // Utility functions - // Internal read function, update pending_request_bytes_ - size_t Read(void* data, size_t size) final { - CHECK_LE(size, pending_request_bytes_); - reader_->Read(data, size); - pending_request_bytes_ -= size; - return size; - } - void Write(const void* data, size_t size) final { - writer_->Write(data, size); - } - // Number of pending bytes requests - size_t pending_request_bytes_; - // The ring buffer to read data from. - support::RingBuffer* reader_; - // The ringr buffer to write reply to. - support::RingBuffer* writer_; - // Session table index. - int rpc_sess_table_index_; - // Name of session. - std::string name_; - // remote key - std::string* remote_key_; -}; - -struct RPCSessTable { +class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; // Get global singleton @@ -864,465 +63,13 @@ struct RPCSessTable { std::array, kMaxRPCSession> tbl_; }; -RPCCode RPCSession::HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) { - RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { - while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - size_t bytes_needed = handler_->BytesNeeded(); - if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); - if (n == 0) { - if (handler_->CanCleanShutdown()) { - return RPCCode::kShutdown; - } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; - } - } - } - code = handler_->HandleNextEvent(rv, client_mode, fwrap); - } - return code; -} - -void RPCSession::Init() { - // Event handler - handler_ = std::make_shared( - &reader_, &writer_, table_index_, name_, &remote_key_); - // Quick function to call remote. - call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); - RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); - }); -} - -std::shared_ptr RPCSession::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { - std::shared_ptr sess = std::make_shared(); - sess->channel_ = std::move(channel); - sess->name_ = std::move(name); - sess->remote_key_ = std::move(remote_key); - sess->table_index_ = RPCSessTable::Global()->Insert(sess); - sess->Init(); - return sess; -} - std::shared_ptr RPCSession::Get(int table_index) { return RPCSessTable::Global()->Get(table_index); } -RPCSession::~RPCSession() { - this->Shutdown(); -} - -void RPCSession::Shutdown() { - if (channel_ != nullptr) { - RPCCode code = RPCCode::kShutdown; - handler_->Write(code); - // flush all writing buffer to output channel. - try { - while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - if (n == 0) break; - } - } catch (const dmlc::Error& e) { - } - channel_.reset(nullptr); - } -} - -void RPCSession::ServerLoop() { - std::lock_guard lock(mutex_); - if (const auto* f = Registry::Get("tvm.rpc.server.start")) { - (*f)(); - } - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { - (*f)(); - } - channel_.reset(nullptr); -} - -int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { - std::lock_guard lock(mutex_); - RPCCode code = RPCCode::kNone; - if (bytes.length() != 0) { - reader_.Write(bytes.c_str(), bytes.length()); - TVMRetValue rv; - code = handler_->HandleNextEvent(&rv, false, nullptr); - } - if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); - if (code == RPCCode::kShutdown) return 0; - if (writer_.bytes_available() != 0) return 2; - return 1; -} - -// Get remote function with name -void RPCSession::CallFunc(void* h, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap) { - std::lock_guard lock(mutex_); - - RPCCode code = RPCCode::kCallFunc; - handler_->Write(code); - uint64_t handle = reinterpret_cast(h); - handler_->Write(handle); - handler_->SendPackedSeq( - args.values, args.type_codes, args.num_args, true, funwrap); - code = HandleUntilReturnEvent(rv, true, fwrap); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); -} - -void RPCSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_to = handler_->StripSessMask(ctx_to); - RPCCode code = RPCCode::kCopyToRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(to); - handler_->Write(handle); - uint64_t offset = static_cast(to_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_to); - handler_->Write(type_hint); - handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn); -} - -void RPCSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_from = handler_->StripSessMask(ctx_from); - RPCCode code = RPCCode::kCopyFromRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(from); - handler_->Write(handle); - uint64_t offset = static_cast(from_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_from); - handler_->Write(type_hint); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); - reader_.Reserve(data_size); - handler_->RequestBytes(data_size); - while (!handler_->Ready()) { - size_t bytes_needed = handler_->BytesNeeded(); - reader_.WriteWithCallback([this](void* data, size_t size) { - size_t n = channel_->Recv(data, size); - CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; - return n; - }, bytes_needed); - } - handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); - handler_->FinishCopyAck(); -} - -RPCFuncHandle RPCSession::GetTimeEvaluator( - RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - return this->CallRemote( - RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms); -} - -// Event handler functions -void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - auto *fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - *rv = static_cast(new tvm::runtime::PackedFunc(*fp)); - } else { - *rv = nullptr; - } -} - -void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - delete static_cast(handle); -} - -void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAPI::Get(ctx)->SetDevice(ctx); -} - -void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAttrKind kind = static_cast(args[1].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPI::Get(ctx, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, rv); - } else { - *rv = 0; - } - } else { - DeviceAPI::Get(ctx)->GetAttr( - ctx, static_cast(kind), rv); - } -} - -void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - uint64_t nbytes = args[1]; - uint64_t alignment = args[2]; - DLDataType type_hint = args[3]; - void* data = DeviceAPI::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); - *rv = data; -} - -void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - void* ptr = args[1]; - DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr); -} - -void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - DeviceAPI::Get(ctx)->StreamSync(ctx, handle); -} - -void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { - void* from = args[0]; - uint64_t from_offset = args[1]; - void* to = args[2]; - uint64_t to_offset = args[3]; - uint64_t size = args[4]; - TVMContext ctx_from = args[5]; - TVMContext ctx_to = args[6]; - DLDataType type_hint = args[7]; - TVMStreamHandle stream = args[8]; - TVMContext ctx = ctx_from; - if (ctx.device_type == kDLCPU) { - ctx = ctx_to; - } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) - << "Can not copy across different ctx types directly"; - } - DeviceAPI::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, to_offset, - size, ctx_from, ctx_to, type_hint, stream); -} - -void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { - static const PackedFunc* fsys_load_ = nullptr; - if (fsys_load_ == nullptr) { - fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module"); - CHECK(fsys_load_ != nullptr); - } - std::string file_name = args[0]; - TVMRetValue ret = (*fsys_load_)(file_name); - // pass via void* - TVMValue value; - int rcode; - ret.MoveToCHost(&value, &rcode); - CHECK_EQ(rcode, kTVMModuleHandle); - *rv = static_cast(value.v_handle); -} - -void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { - void* pmod = args[0]; - void* cmod = args[1]; - ObjectInternal::GetModuleNode(pmod)->Import( - GetRef(ObjectInternal::GetModuleNode(cmod))); -} - -void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - ObjectInternal::ObjectFree(mhandle); -} - -void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( - args[1], false); - if (pf != nullptr) { - *rv = static_cast(new PackedFunc(pf)); - } else { - *rv = nullptr; - } -} - -void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - std::string fmt = args[1]; - *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); -} - -void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - static_cast( - reinterpret_cast(handle))->DecRef(); -} - -void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { - PackedFunc *pf = static_cast(args[0].operator void*()); - void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4])); - delete pf; - *rv = fhandle; -} - -void RPCSession::EventHandler::HandlePackedCall() { - CHECK_EQ(pending_request_bytes_, 0U); - if (code_ == RPCCode::kReturn) { - state_ = kReturnReceived; return; - } - // reset state to clean init state - state_ = kRecvCode; - this->RequestBytes(sizeof(RPCCode)); - // Event handler sit at clean state at this point. - switch (code_) { - case RPCCode::kCallFunc: { - PackedFunc* pf = reinterpret_cast(call_handle_); - CallHandler([pf](TVMArgs args, TVMRetValue* rv) { - pf->CallPacked(args, rv); - }); - break; - } - case RPCCode::kException: { - CHECK_EQ(arg_buf_->value.size(), 1U); - CHECK_EQ(arg_buf_->tcode[0], kTVMStr); - std::ostringstream os; - os << "Except caught from RPC call: " << arg_buf_->value[0].v_str; - arg_buf_.reset(); - throw dmlc::Error(os.str()); - break; - } - // system functions - case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; - case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; - case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; - case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; - case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; - case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; - case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; - case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; - case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; - case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - CHECK_EQ(state_, kRecvCode); -} - -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { - auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); - CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; - return (*get_micro_time_evaluator)(pf, ctx, number, repeat); - } - - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - tbegin = std::chrono::high_resolution_clock::now(); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - tend = std::chrono::high_resolution_clock::now(); - - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; - } while (duration_ms < min_repeat_ms); - - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -size_t CallbackChannel::Send(const void* data, size_t size) { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - int64_t n = fsend_(bytes); - if (n == -1) { - support::Socket::Error("CallbackChannel::Send"); - } - return static_cast(n); -} - -size_t CallbackChannel::Recv(void* data, size_t size) { - TVMRetValue ret = frecv_(size); - - if (ret.type_code() != kTVMBytes) { - support::Socket::Error("CallbackChannel::Recv"); - } - std::string* bytes = ret.ptr(); - memcpy(static_cast(data), bytes->c_str(), bytes->length()); - return bytes->length(); +void RPCSession::InsertToSessionTable(std::shared_ptr sess) { + CHECK_EQ(sess->table_index_, 0); + sess->table_index_ = RPCSessTable::Global()->Insert(sess); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index db63be4be74d..a715e7b0b20c 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,230 +24,166 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ + #include #include -#include -#include +#include #include -#include -#include "../../support/ring_buffer.h" +#include namespace tvm { namespace runtime { -// Magic header for RPC data plane -const int kRPCMagic = 0xff271; -// magic header for RPC tracker(control plane) -const int kRPCTrackerMagic = 0x2f271; -// sucess response -const int kRPCSuccess = kRPCMagic + 0; -// cannot found matched key in server -const int kRPCMismatch = kRPCMagic + 2; - -/*! \brief Enumeration code for the RPC tracker */ -enum class TrackerCode : int { - kFail = -1, - kSuccess = 0, - kPing = 1, - kStop = 2, - kPut = 3, - kRequest = 4, - kUpdateInfo = 5, - kSummary = 6, - kGetPendingMatchKeys = 7 -}; -/*! \brief The remote functio handle */ -using RPCFuncHandle = void*; - -struct RPCArgBuffer; - -/*! \brief The RPC code */ -enum class RPCCode : int { - kNone, - kCallFunc, - kReturn, - kException, - kShutdown, - kCopyFromRemote, - kCopyToRemote, - kCopyAck, - // The following are code that can send over CallRemote - kSystemFuncStart, - kGetGlobalFunc, - kGetTimeEvaluator, - kFreeFunc, - kDevSetDevice, - kDevGetAttr, - kDevAllocData, - kDevFreeData, - kDevStreamSync, - kCopyAmongRemote, - kModuleLoad, - kModuleImport, - kModuleFree, - kModuleGetFunc, - kModuleGetSource, - kNDArrayFree -}; - /*! - * \brief Function that unwraps a remote object to its handle. - * \param rpc_sess_table_index RPC session table index for validation. - * \param obj Handle to the object argument. - * \return The corresponding handle. - */ -typedef void* (*FUnwrapRemoteObject)( - int rpc_sess_table_index, - const TVMArgValue& obj); - -/*! - * \brief Abstract channel interface used to create RPCSession. + * \brief The interface of all remote RPC sessions. + * + * It contains all the necessary interface to implement + * remote call and resource management. + * + * The interface is designed to allow easy proxy-chaining + * by forward requests to another RPCSession. */ -class RPCChannel { +class RPCSession { public: - /*! \brief virtual destructor */ - virtual ~RPCChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - virtual size_t Send(const void* data, size_t size) = 0; + /*! \brief PackedFunc Handle in the remote. */ + using PackedFuncHandle = void*; + + /*! \brief Module handle in the remote. */ + using ModuleHandle = void*; + + /*! \brief NDArray handle in the remote. */ + using NDArrayHandle = void*; + /*! - * \brief Recv data from channel. + * \brief Callback to send an encoded return values via encode_args. + * + * \param encode_args The arguments that we can encode the return values into. + * \param ret_tcode The actual remote type code of the return value. * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. + * Encoding convention (as list of arguments): + * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. + * - PackedFunc/Module: [tcode: int, handle: void*] + * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] + * DLTensor* contains the meta-data as well as handle into the remote data. + * nd_handle can be used for deletion. */ - virtual size_t Recv(void* data, size_t size) = 0; -}; + using FEncodeReturn = std::function; + + /*! \brief Destructor.*/ + virtual ~RPCSession() {} -// Bidirectional Communication Session of PackedRPC -class RPCSession { - public: - /*! \brief virtual destructor */ - ~RPCSession(); /*! - * \brief The server loop that server runs to handle RPC calls. + * \brief Get function in the session. + * \param name The name of the function. + * \return The function handle. */ - void ServerLoop(); + virtual PackedFuncHandle GetFunction(const std::string& name) = 0; + /*! - * \brief Message handling function for event driven server. - * Called when the server receives a message. - * Event driven handler will never call recv on the channel - * and always relies on the ServerEventHandler. - * to receive the data. + * \brief Call into a remote Packed function. * - * \param in_bytes The incoming bytes. - * \param event_flag 1: read_available, 2: write_avaiable. - * \return State flag. - * 1: continue running, no need to write, - * 2: need to write - * 0: shutdown - */ - int ServerEventHandler(const std::string& in_bytes, - int event_flag); - /*! - * \brief Call into remote function - * \param handle The function handle - * \param args The arguments - * \param rv The return value. - * \param funpwrap Function that takes a remote object and returns the raw handle. - * \param fwrap Wrapper function to turn Function/Module handle into real return. + * Calling convention: + * + * - type_code is follows the PackedFunc convention. + * - int/float/string/bytes follows the PackedFunc convention, all data are local. + * - PackedFunc/Module and future remote objects: pass remote handle instead. + * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor + * points to a remote data handle returned by the Device API. + * The meta-data of the DLTensor sits on local. + * + * The caller populates the arguments and manages these arguments. + * + * The callee can change the content of arg_values and arg_type_codes + * if they want to do inplace modify and forward. + * + * The callee need to store the return value into ret_value. + * - PackedFunc/Module are stored as void* + * - NDArray is stored as local NDArray, whose data field is a remote handle. + * Notably the NDArray's deleter won't delete remote handle. + * It is up to the user of the RPCSession to such wrapping. + * - In short, remote handles are "moved" as return values + * and the callee needs to explicitly manage them by calling + * the deleter functions when they are no longer needed. + * + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to set the return value, + * if not called, return value is null. */ - void CallFunc(RPCFuncHandle handle, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap); + virtual void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) = 0; + /*! * \brief Copy bytes into remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. - * \param to The target array. - * \param to_offset The byte offset in the to. + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_to The target context. + * \param remote_ctx_to The target context. * \param type_hint Hint of content data type. */ - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint); + virtual void CopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint) = 0; /*! * \brief Copy bytes from remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_from The source context. + * \param remote_ctx_from The source context in the remote. * \param type_hint Hint of content data type. */ - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint); + virtual void CopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint) = 0; + /*! - * \brief Get a remote timer function on ctx. - * This function consumes fhandle, caller should not call Free on fhandle. - * - * \param fhandle The function handle. - * \param ctx The ctx to run measurement on. - * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. - * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. - * \return A remote timer function + * \brief Free a remote function. + * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param type_code The type code of the underlying type. */ - RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms); + virtual void FreeHandle(void* handle, int type_code) = 0; + /*! - * \brief Call a remote defined system function with arguments. - * \param fcode The function code. - * \param args The arguments - * \return The returned remote value. + * \brief Get device API that represents the remote + * actions that can be taken on the remote. + * + * The caller can then call into the Alloc/Free functions + * to allocate free spaces and taking the pointer as the handle. + * + * The device API is guaranteed to be alive during the + * lifetime of the Session. + * + * \param ctx The remote context. + * \param allow_missing Whether can we return nullptr if it is not available. + * + * \return The device API. */ - template - inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args); + virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0; + /*! * \return The session table index of the session. */ int table_index() const { return table_index_; } - /*! - * \brief Create a RPC session with given channel. - * \param channel The communication channel. - * \param name The local name of the session, used for debug - * \param remote_key The remote key of the session - * if remote_key equals "%toinit", we need to re-intialize - * it by event handler. - */ - static std::shared_ptr Create( - std::unique_ptr channel, - std::string name, - std::string remote_key); + /*! * \brief Try get session from the global session table by table index. * \param table_index The table index of the session. @@ -256,62 +192,25 @@ class RPCSession { static std::shared_ptr Get(int table_index); private: - class EventHandler; - // Handle events until receives a return - // Also flushes channels so that the function advances. - RPCCode HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap); - // Initalization - void Init(); - // Shutdown - void Shutdown(); - // Internal channel. - std::unique_ptr channel_; - // Internal mutex - std::recursive_mutex mutex_; - // Internal ring buffer. - support::RingBuffer reader_, writer_; - // Event handler. - std::shared_ptr handler_; - // call remote with specified function code. - PackedFunc call_remote_; - // The index of this session in RPC session table. + /*! \brief index of this session in RPC session table */ int table_index_{0}; - // The name of the session. - std::string name_; - // The remote key - std::string remote_key_; + /*! \brief Insert the current session to the session table.*/ + static void InsertToSessionTable(std::shared_ptr sess); + // friend declaration + friend Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! - * \brief RPC channel which callback - * frontend (Python/Java/etc.)'s send & recv function + * \brief Remote space handle cell used by the RPC runtime API. + * + * When we allocate space using a rpc context, the data pointer + * points to an allocated RemoteSpace. */ -class CallbackChannel final : public RPCChannel { - public: - explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) - : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} - - ~CallbackChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - size_t Send(const void* data, size_t size) final; - /*! - * \brief Recv data from channel. - * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. - */ - size_t Recv(void* data, size_t size) final; - - private: - PackedFunc fsend_; - PackedFunc frecv_; +struct RemoteSpace { + /*! \brief The remote data handle. */ + void* data; + /*! \brief Reference to the underlying RPC session. */ + std::shared_ptr sess; }; /*! @@ -319,18 +218,18 @@ class CallbackChannel final : public RPCChannel { * \param f The function argument. * \param ctx The context. * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. + * We call these runs as one `repeat` of measurement. * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warm up and will be discarded. + * The returned result contains `repeat` costs, + * each of which is an average of `number` costs. * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. + * By default, one `repeat` contains `number` runs. If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * i.e., When the run time of one `repeat` falls below this time, + * the `number` parameter will be automatically increased. * \return f_timer A timer function. */ PackedFunc WrapTimeEvaluator(PackedFunc f, @@ -344,21 +243,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f, * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCModule(std::shared_ptr sess); +Module CreateRPCSessionModule(std::shared_ptr sess); -// Remote space pointer. -struct RemoteSpace { - void* data; - std::shared_ptr sess; -}; +/*! + * \brief Get the session module from a RPC session Module. + * \param mod The input module(must be an RPCModule). + * \return The internal RPCSession. + */ +std::shared_ptr RPCModuleGetSession(Module mod); -// implementation of inline functions -template -inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { - std::lock_guard lock(mutex_); - writer_.Write(&code, sizeof(code)); - return call_remote_(std::forward(args)...); -} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_ diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 642fbb8ec7f2..f3a30dd6c485 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -22,8 +22,11 @@ * \brief Socket based RPC implementation. */ #include +#include #include +#include "rpc_endpoint.h" #include "rpc_session.h" +#include "rpc_local_session.h" #include "../../support/socket.h" namespace tvm { @@ -61,8 +64,8 @@ class SockChannel final : public RPCChannel { support::TCPSocket sock_; }; -std::shared_ptr -RPCConnect(std::string url, int port, std::string key) { +std::shared_ptr +RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); @@ -96,42 +99,56 @@ RPCConnect(std::string url, int port, std::string key) { remote_key.resize(keylen); CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - return RPCSession::Create( + auto endpt = RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), key, remote_key); + endpt->InitRemoteSession(init_seq); + return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key) { - return CreateRPCModule(RPCConnect(url, port, "client:" + key)); +Module RPCClientConnect(std::string url, + int port, + std::string key, + TVMArgs init_seq) { + auto endpt = RPCConnect(url, port, "client:" + key, init_seq); + return CreateRPCSessionModule(CreateClientSession(endpt)); } // TVM_DLL needed for MSVC TVM_DLL void RPCServerLoop(int sockfd) { support::TCPSocket sock( static_cast(sockfd)); - RPCSession::Create( + RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), "SockServerLoop", "")->ServerLoop(); } -void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { - RPCSession::Create(std::unique_ptr( - new CallbackChannel(fsend, frecv)), +void RPCServerLoop(PackedFunc fsend, + PackedFunc frecv) { + RPCEndpoint::Create( + std::unique_ptr(new CallbackChannel(fsend, frecv)), "SockServerLoop", "")->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc._Connect") -.set_body_typed(RPCClientConnect); +TVM_REGISTER_GLOBAL("rpc.Connect") +.set_body([](TVMArgs args, TVMRetValue *rv) { + std::string url = args[0]; + int port = args[1]; + std::string key = args[2]; + *rv = RPCClientConnect( + url, port, key, + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); +}); -TVM_REGISTER_GLOBAL("rpc._ServerLoop") +TVM_REGISTER_GLOBAL("rpc.ServerLoop") .set_body([](TVMArgs args, TVMRetValue* rv) { - if (args.size() == 1) { - RPCServerLoop(args[0]); - } else { - CHECK_EQ(args.size(), 2); - RPCServerLoop( - args[0].operator tvm::runtime::PackedFunc(), - args[1].operator tvm::runtime::PackedFunc()); - } - }); + if (args[0].type_code() == kDLInt) { + RPCServerLoop(args[0]); + } else { + RPCServerLoop( + args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); + } +}); + } // namespace runtime } // namespace tvm diff --git a/src/support/arena.h b/src/support/arena.h index 744ff4f12188..b06227680808 100644 --- a/src/support/arena.h +++ b/src/support/arena.h @@ -26,42 +26,114 @@ #ifndef TVM_SUPPORT_ARENA_H_ #define TVM_SUPPORT_ARENA_H_ +#ifndef TVM_ARENA_HAS_DESTRUCTOR +#define TVM_ARENA_HAS_DESTRUCTOR 1 +#endif + +#include #include #include + namespace tvm { namespace support { -const constexpr int kArenaPageSize = 16 << 10; +/*! + * \brief An arena page header. + */ +struct ArenaPageHeader { + /*! \brief points to the next page. */ + ArenaPageHeader* next; + /*! + * \brief Total size of the page. + */ + size_t size; + /*! \brief memory allocator offset inside page. */ + size_t offset; +}; + +/*! + * \brief Simple page allocator that uses new and delete. + */ +class SimplePageAllocator { + public: + /*! + * \brief Allocate a new page. + * \param min_size Minimum size of the page. + * \return The allocated page. + * \note This function can return a bigger page to meet the min_size requirement. + */ + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + ArenaPageHeader* header = reinterpret_cast(new Page[npages]); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + /*! + * \brief De-allocate an allocate page. + * \param page The page to be de-allocated. + */ + void deallocate(ArenaPageHeader* page) { + delete [] reinterpret_cast(page); + } + + static const constexpr int kPageSize = 16 << 10; + static const constexpr int kPageAlign = 1024; + + private: + // page size 16 KB + // The page data type; + using Page = std::aligned_storage::type; +}; /*! * \brief Arena allocator that allocates memory from continuous * chunk and frees them all only during destruction. */ -class Arena { +template +class GenericArena { public: - Arena() { + explicit GenericArena(PageAllocator alloc = PageAllocator()) + : alloc_(alloc) { // eagerly allocate the first page. - head_ = reinterpret_cast(new Page()); + head_ = tail_ = alloc_.allocate(1); head_->next = nullptr; - head_->ptr = sizeof(PageHeader); } - ~Arena() { - // delete all the allocated pages. - while (head_ != nullptr) { - Page* page = reinterpret_cast(head_); - head_ = head_->next; - delete page; - } + +#if TVM_ARENA_HAS_DESTRUCTOR + ~GenericArena() { + this->FreeAll(); + } +#endif + + /*! \brief Free all pages. */ + void FreeAll() { + FreePageList(&head_); + FreePageList(&free_list_); + } + /*! \brief Recycle all the pages in the arena */ + void RecycleAll() { + // put all the current list to the free list. + tail_->next = free_list_; + // allocate the first in the free list to head + free_list_ = head_->next; + head_->next = nullptr; + // Reset the head. + head_->offset = sizeof(ArenaPageHeader); + tail_ = head_; } /*! * \brief Allocate a space from Arena for type T * \param T the data type to be allocated + * \param count Numberof elements * \note The space of T is not initialized. */ template - T* allocate_() { - return static_cast(Alloc(sizeof(T), alignof(T))); + T* allocate_(int count = 1) { + static_assert(PageAllocator::kPageAlign % alignof(T) == 0, + "To large alignment"); + return static_cast(Alloc(sizeof(T) * count, alignof(T))); } /*! * \brief Create a new instance of type T. @@ -82,25 +154,21 @@ class Arena { } private: - // page size 16 KB - // The page data type; - using Page = std::aligned_storage::type; - /*! \brief Page header */ - struct PageHeader { - /*! \brief points to the next page */ - PageHeader* next; - /*! \brief memory allocator ptr inside page */ - size_t ptr; - }; - /* \brief The page header */ - PageHeader* head_{nullptr}; + /*! \brief internal page allocator. */ + PageAllocator alloc_; + /* \brief The the head of the allocated list. */ + ArenaPageHeader* head_{nullptr}; + /*! \brief The tail of the allocated list. */ + ArenaPageHeader* tail_{nullptr}; + /* \brief List of free pages. */ + ArenaPageHeader* free_list_{nullptr}; /*! * \brief Align ptr by upper bound. - * \param ptr The pointer value. + * \param offset The offset value. * \param align The alignment requirement. */ - size_t UpperAlign(size_t ptr, size_t align) { - return ptr + (align - (ptr % align)) % align; + size_t UpperAlign(size_t offset, size_t align) { + return offset + (align - (offset % align)) % align; } /*! * \brief Internal aligned alloc function. @@ -108,22 +176,41 @@ class Arena { * \param align The alignment requirement. */ void* Alloc(size_t size, size_t align) { - size_t ptr = UpperAlign(head_->ptr, align); - if (ptr + size <= kArenaPageSize) { - head_->ptr = ptr + size; - return reinterpret_cast(head_) + ptr; + size_t offset = UpperAlign(head_->offset, align); + if (offset + size <= head_->size) { + head_->offset = offset + size; + return reinterpret_cast(head_) + offset; } else { - PageHeader* new_head = reinterpret_cast(new Page()); + ArenaPageHeader* new_head; + offset = UpperAlign(sizeof(ArenaPageHeader), align); + if (free_list_ != nullptr && offset + size <= free_list_-> size) { + new_head = free_list_; + free_list_ = free_list_->next; + } else { + new_head = alloc_.allocate(offset + size); + } new_head->next = head_; - ptr = UpperAlign(sizeof(PageHeader), align); - CHECK_LE(ptr + size, kArenaPageSize); - new_head->ptr = ptr + size; + new_head->offset = offset + size; head_ = new_head; - return reinterpret_cast(head_) + ptr; + return reinterpret_cast(head_) + offset; + } + } + /*! + * \brief Free all the pages in the list. + * \param ptr The head ptr. + */ + void FreePageList(ArenaPageHeader** ptr) { + // delete all the allocated pages. + while (ptr[0] != nullptr) { + ArenaPageHeader* temp = ptr[0]; + ptr[0] = ptr[0]->next; + alloc_.deallocate(temp); } } }; +using Arena = GenericArena; + /*! * \brief Link list node * \tparam T the content data type diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index b61e6bb9fa01..1757c995a78c 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -22,6 +22,7 @@ import time import multiprocessing +import pytest import numpy as np from tvm import rpc from tvm.contrib import util @@ -77,11 +78,9 @@ def remotethrow(name): f1 = client.get_function("rpc.test.addone") assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - try: + + with pytest.raises(tvm.error.RPCError): f3("abc") - assert False - except tvm.error.TVMError as e: - assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" @@ -101,6 +100,40 @@ def remote_array_func(y): fremote = remote.get_function("rpc.test.remote_array_func") fremote(r_cpu) + +def test_rpc_echo(): + def check(remote): + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + with pytest.raises(RuntimeError): + raise_err = remote.get_function( + "testing.test_raise_error_callback")("RuntimeError") + raise_err() + + temp = rpc.server._server_env([]) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + check(rpc.LocalSession()) + + check(client) + # Test minrpc server. + temp = util.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc("g++")(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec)) + # minrpc on the remote + server = rpc.Server("localhost") + client = rpc.connect( + server.host, server.port, + session_constructor=["rpc.PopenSession", + open(minrpc_exec, "rb").read()]) + check(client) + + def test_rpc_file_exchange(): if not tvm.runtime.enabled("rpc"): return @@ -114,14 +147,15 @@ def test_rpc_file_exchange(): def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) # graph - n = tvm.runtime.convert(1024) + n = tvm.runtime.convert(102) A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + def check_remote(remote): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") @@ -133,13 +167,34 @@ def check_remote(remote): f.export_library(path_dso) remote.upload(path_dso) f1 = remote.load_module("dev_lib.so") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + def check_minrpc(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None: + return + # export to minrpc + temp = util.tempdir() + f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") + path_minrpc = temp.relpath("dev_lib.minrpc") + f.export_library(path_minrpc, rpc.with_minrpc("g++")) + # statrt the minrpc session. + remote = tvm.rpc.PopenSession(path_minrpc) + ctx = remote.cpu(0) + f1 = remote.system_lib() + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) + time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) + cost = time_f(a, b).mean + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + def check_remote_link_cl(remote): """Test function to run remote code such as cl @@ -174,8 +229,8 @@ def check_remote_link_cl(remote): fhost = remote.load_module("myadd.o") fdev = remote.load_module("myadd.cl") fhost.import_module(fdev) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) # Option 2: export library as a tar ball then handled by remote compiler @@ -183,13 +238,15 @@ def check_remote_link_cl(remote): f.export_library(path_tar) remote.upload(path_tar) fhost = remote.load_module("myadd.tar") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote(client) check_remote(rpc.LocalSession()) + check_remote(client) + check_minrpc() + def test_rpc_return_func(): @@ -204,6 +261,37 @@ def addone(x): assert fadd(12) == 22 +def test_rpc_session_constructor(): + # start server + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + def check_multi_hop(): + # use server0 as proxy to connect to server1 + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + + fecho = client.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + nd = tvm.nd.array([1,2,3], ctx=client.cpu(0)) + assert(nd.asnumpy()[1] == 2) + + def check_error_handling(): + with pytest.raises(tvm.error.RPCError): + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor=["rpc.NonExistingConstructor"]) + + check_multi_hop() + check_error_handling() + + def test_rpc_return_ndarray(): # Use closure to check the ref counter correctness nd = tvm.nd.array(np.zeros(10).astype("float32")) @@ -221,6 +309,7 @@ def my_module(name): # start server server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") + m = client.get_function("rpc.test.remote_return_nd") get_arr = m("get_arr") ref_count = m("ref_count") @@ -315,6 +404,7 @@ def target(host, port, device_key, timeout): time.sleep(0.5) summary = client.summary() + assert summary['queue_info'][device_key]['free'] == 0 assert summary['queue_info'][device_key]['pending'] == 1 @@ -334,6 +424,8 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_echo() + test_rpc_session_constructor() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js index b62b298d969e..86ef59cb73b1 100644 --- a/web/tvm_runtime.js +++ b/web/tvm_runtime.js @@ -907,7 +907,7 @@ var tvm_runtime = tvm_runtime || {}; if (typeof systemFunc.fcreateServer === "undefined") { systemFunc.fcreateServer = - getGlobalFunc("rpc._CreateEventDrivenServer"); + getGlobalFunc("rpc.CreateEventDrivenServer"); } if (systemFunc.fcreateServer == null) { throwError("RPCServer is not included in runtime"); From 20be767b162f27f77dea7ca2b51bba3516343366 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 4 May 2020 13:02:05 -0700 Subject: [PATCH 3/4] Address review comments --- apps/bundle_deploy/runtime.cc | 1 + apps/cpp_rpc/rpc_tracker_client.h | 2 +- python/tvm/_ffi/_ctypes/packed_func.py | 19 +++++++------------ python/tvm/_ffi/_cython/packed_func.pxi | 20 +++++++------------- python/tvm/error.py | 2 +- python/tvm/rpc/client.py | 20 +++++++++++++------- python/tvm/rpc/minrpc.py | 2 +- src/runtime/rpc/minrpc/minrpc_server.h | 16 ++++++++-------- src/runtime/rpc/rpc_endpoint.h | 8 ++++---- tests/python/unittest/test_runtime_rpc.py | 21 ++++++++++++++++----- 10 files changed, 59 insertions(+), 52 deletions(-) diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 844f404d98f4..7a116e89fa88 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ + #include #include #include diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index 9b9a707dd376..112f7d214a27 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -31,7 +31,7 @@ #include #include -#include "../../src/runtime/rpc/rpc_end_point.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/support/socket.h" namespace tvm { diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6d2b966b8815..b17174a7c6bf 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -141,18 +141,13 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) type_codes[i] = TypeCode.TVM_CONTEXT - elif isinstance(arg, bytes): - byte_arr = bytearray(arg) - arr = TVMByteArray() - arr.data = ctypes.cast( - (ctypes.c_byte * len(arg)).from_buffer(byte_arr), - ctypes.POINTER(ctypes.c_byte)) - arr.size = len(arg) - values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) - temp_args.append(byte_arr) - temp_args.append(arr) - type_codes[i] = TypeCode.BYTES - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytearray, bytes)): + # from_buffer only taeks in bytearray. + if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 4a1bdfd97817..45bcf64a616d 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -142,19 +142,13 @@ cdef inline int make_arg(object arg, value[0].v_ctx = (( ctypes.addressof(arg)))[0] tcode[0] = kTVMContext - elif isinstance(arg, bytes): - byte_arr = bytearray(arg) - arr = TVMByteArray() - arr.data = ctypes.cast( - (ctypes.c_byte * len(arg)).from_buffer(byte_arr), - ctypes.POINTER(ctypes.c_byte)) - arr.size = len(arg) - value[0].v_handle = ( - ctypes.addressof(arr)) - tcode[0] = kTVMBytes - temp_args.append(byte_arr) - temp_args.append(arr) - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytes, bytearray)): + # from_buffer only taeks in bytearray. + if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/error.py b/python/tvm/error.py index 7366fdb3be39..b3502f6b0ead 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -59,7 +59,7 @@ def __init__(self, msg): @register_error class RPCError(RuntimeError): - """Error thrown by the RPC call.""" + """Error thrown by the remote server handling the RPC call.""" @register_error diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index d4250353f8f9..9997673f52c2 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -210,10 +210,14 @@ def _popen_session(binary): path_exec = temp.relpath("server.minrpc") with open(path_exec, "wb") as outfile: outfile.write(binary) - os.chmod(path_exec, stat.S_IXUSR) + os.chmod(path_exec, stat.S_IXUSR | stat.S_IRUSR) path_exec = os.path.abspath(path_exec) else: path_exec = os.path.abspath(binary) + if not os.path.isfile(path_exec): + raise RuntimeError(f"{path_exec} does not exist.") + if not os.access(path_exec, os.X_OK): + raise RuntimeError(f"{path_exec} is not executable.") sess = _ffi_api.CreatePipeClient(path_exec) return sess @@ -402,7 +406,7 @@ def request_and_run(self, key, max_retry, str(last_err))) -def connect(url, port, key="", session_timeout=0, session_constructor=None): +def connect(url, port, key="", session_timeout=0, session_constructor_args=None): """Connect to RPC Server Parameters @@ -421,8 +425,10 @@ def connect(url, port, key="", session_timeout=0, session_constructor=None): the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. - session_constructor: List + session_constructor_args: List List of additional arguments to passed as the remote session constructor. + The first element of the list is always a string specifying the name of + the session constructor, the following args are the positional args to that function. Returns ------- @@ -445,17 +451,17 @@ def connect(url, port, key="", session_timeout=0, session_constructor=None): client_via_proxy = rpc.connect( proxy_server_url, proxy_server_port, proxy_server_key, - session_constructor=[ + session_constructor_args=[ "rpc.Connect", internal_url, internal_port, internal_key]) """ try: if session_timeout: key += " -timeout=%s" % str(session_timeout) - session_constructor = session_constructor if session_constructor else [] - if not isinstance(session_constructor, (list, tuple)): + session_constructor_args = session_constructor_args if session_constructor_args else [] + if not isinstance(session_constructor_args, (list, tuple)): raise TypeError("Expect the session constructor to be a list or tuple") - sess = _ffi_api.Connect(url, port, key, *session_constructor) + sess = _ffi_api.Connect(url, port, key, *session_constructor_args) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 768b52886ca2..29257c1c3bb2 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -52,7 +52,7 @@ def with_minrpc(compile_func, Parameters ---------- - compile_func : function + compile_func : Union[str, Callable[[str, str, Optional[str]], None]] The compilation function to decorate. server : str diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h index 370720d14be6..63ad359fef09 100644 --- a/src/runtime/rpc/minrpc/minrpc_server.h +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -173,12 +173,12 @@ class MinRPCServer { this->Read(&ctx); this->Read(&type_hint); - char* data_ptr; + uint8_t* data_ptr; int call_ecode = 0; if (ctx.device_type == kDLCPU) { - data_ptr = reinterpret_cast(handle) + offset; + data_ptr = reinterpret_cast(handle) + offset; } else { - data_ptr = this->ArenaAlloc(num_bytes); + data_ptr = this->ArenaAlloc(num_bytes); call_ecode = TVMDeviceCopyDataFromTo( reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, @@ -211,10 +211,10 @@ class MinRPCServer { int call_ecode = 0; if (ctx.device_type == kDLCPU) { - char* dptr = reinterpret_cast(handle) + offset; - this->ReadArray(dptr, num_bytes); + uint8_t* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); } else { - char* temp_data = this->ArenaAlloc(num_bytes); + uint8_t* temp_data = this->ArenaAlloc(num_bytes); this->ReadArray(temp_data, num_bytes); call_ecode = TVMDeviceCopyDataFromTo( @@ -551,7 +551,7 @@ class MinRPCServer { } void ReadRawBytes(void* data, size_t size) { - char* buf = reinterpret_cast(data); + uint8_t* buf = reinterpret_cast(data); size_t ndone = 0; while (ndone < size) { ssize_t ret = io_.PosixRead(buf, size - ndone); @@ -572,7 +572,7 @@ class MinRPCServer { } void WriteRawBytes(const void* data, size_t size) { - const char *buf = reinterpret_cast(data); + const uint8_t *buf = reinterpret_cast(data); size_t ndone = 0; while (ndone < size) { ssize_t ret = io_.PosixWrite(buf, size - ndone); diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 0fc064baa702..9a6afcdc1ca4 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -93,13 +93,13 @@ class RPCEndpoint { * * If no session constructor arguments is passed, LocalSession will be used in the remote. * Otherwise the remote serving session will be constructed using the arguments - * specified in the session_constructor. + * specified in the session_constructor_args. * * The construction rule can be summarized as follows: * * \code * - * auto args = session_constructor; + * auto args = session_constructor_args; * int n = args.size(); * if (n != 0) { * std::string constructor = args[0]; @@ -110,9 +110,9 @@ class RPCEndpoint { * } * \endcode * - * \param session_constructor Optional sequence of the remote sesssion constructor. + * \param session_constructor_args Optional sequence of the remote sesssion constructor. */ - void InitRemoteSession(TVMArgs session_constructor); + void InitRemoteSession(TVMArgs session_constructor_args); /*! * \brief Call into remote function diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 1757c995a78c..091e9427a25a 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -18,6 +18,7 @@ from tvm import te import tvm.testing import os +import stat import logging import time import multiprocessing @@ -129,7 +130,7 @@ def check(remote): server = rpc.Server("localhost") client = rpc.connect( server.host, server.port, - session_constructor=["rpc.PopenSession", + session_constructor_args=["rpc.PopenSession", open(minrpc_exec, "rb").read()]) check(client) @@ -185,6 +186,10 @@ def check_minrpc(): f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") path_minrpc = temp.relpath("dev_lib.minrpc") f.export_library(path_minrpc, rpc.with_minrpc("g++")) + + with pytest.raises(RuntimeError): + rpc.PopenSession("filenotexist") + # statrt the minrpc session. remote = tvm.rpc.PopenSession(path_minrpc) ctx = remote.cpu(0) @@ -195,6 +200,12 @@ def check_minrpc(): cost = time_f(a, b).mean np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + # change to not executable + os.chmod(path_minrpc, stat.S_IRUSR) + with pytest.raises(RuntimeError): + rpc.PopenSession(path_minrpc) + + def check_remote_link_cl(remote): """Test function to run remote code such as cl @@ -261,7 +272,7 @@ def addone(x): assert fadd(12) == 22 -def test_rpc_session_constructor(): +def test_rpc_session_constructor_args(): # start server server0 = rpc.Server("localhost", key="x0") server1 = rpc.Server("localhost", key="x1") @@ -270,7 +281,7 @@ def check_multi_hop(): # use server0 as proxy to connect to server1 client = rpc.connect( server0.host, server0.port, key="x0", - session_constructor=[ + session_constructor_args=[ "rpc.Connect", server1.host, server1.port, "x1"]) fecho = client.get_function("testing.echo") @@ -286,7 +297,7 @@ def check_error_handling(): with pytest.raises(tvm.error.RPCError): client = rpc.connect( server0.host, server0.port, key="x0", - session_constructor=["rpc.NonExistingConstructor"]) + session_constructor_args=["rpc.NonExistingConstructor"]) check_multi_hop() check_error_handling() @@ -425,7 +436,7 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) test_rpc_echo() - test_rpc_session_constructor() + test_rpc_session_constructor_args() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() From 2236a9c3e48c686b33d7369b1c8e6a90649b03e3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 4 May 2020 14:11:49 -0700 Subject: [PATCH 4/4] Remove ld library path --- python/tvm/_ffi/base.py | 4 ---- python/tvm/rpc/minrpc.py | 9 ++++++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 61360107f671..8674e31c3b84 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -49,10 +49,6 @@ def _load_lib(): lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) lib.TVMGetLastError.restype = ctypes.c_char_p - # Put the libpath to LD_LIBRARY_PATH - # will be useful for pipe session to find libtvm - os.environ["LD_LIBRARY_PATH"] = "%s:%s" % ( - os.path.dirname(lib_path[0]), os.environ.get("LD_LIBRARY_PATH", "")) return lib, os.path.basename(lib_path[0]) # version number diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 29257c1c3bb2..760c5362f11d 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -70,9 +70,16 @@ def with_minrpc(compile_func, runtime_path = libinfo.find_lib_path( [runtime, runtime + ".so", runtime + ".dylib"])[0] + runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) + options = ["-std=c++14"] + # Make sure the rpath to the libtvm is set so we can do local tests. + # Note that however, this approach won't work on remote. + # Always recommend to to link statically. + options += ["-Wl,-rpath=" + runtime_dir] + options += ["-I" + path for path in libinfo.find_include_path()] fcompile = cc.cross_compiler( compile_func, - options=["-std=c++14"] + ["-I" + path for path in libinfo.find_include_path()], + options=options, add_files=[server_path, runtime_path]) fcompile.__name__ = "with_minrpc" fcompile.need_system_lib = True