Skip to content

Commit

Permalink
[REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra (#5484)
Browse files Browse the repository at this point in the history
* Update dmlc-core which was mistakenly overriden

* [REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra.

This PR refactors the RPC protocol to make it more modularized.

- RPCSession: represent a set of features that need to be implemented
- RPCEndPont: End point that forwards the RPCSession requests over a communication channel.
- RPCModule: Exposes an RPCSession as an rpc device in the TVM Runtime API.

In the new design, the local machine is presented as a special case of RPCSession.
The remote is just another client session that calls into RPCEndPoint.
The RPC communication path is as follows.

```
client -> ClientSession -> EndPoint[client@n0]
-> networking[between n0 <=> n1]
-> EndPoint[server@n1] -> LocalSession[@n1]

```

Because of the new modular design, we can now chain more sessions together.
For example, we can now run the following proxy setup (testcase in test_runtime_rpc.test_session_constructor).

```
client -> ClientSession -> Endpoint[client@n0]
-> networking[between n0 <=> n1]
-> Endpoint[server@n1] -> ClientSession -> Endpoint[client@n1]
-> networking[between n1 <=> n2]
-> Endpoint[server@n2] -> LocalSession[@n2]
```

We can also implement other types of Sessions.
As an example, We introduced a PopenSession that communicates with
the another process via a pipe.

We also add more comments about the internal of the RPC.
The communication protocol is simplfied using a similar convention as PackedFunc.
This allows us to further reduce the amount of special remote syscalls.

Due to the major improvement and simplification, we are making a non-compatible update to the RPC protocol.
It means that the client and server needs to be upgraded to together in order for it to function correctly.

This PR also introduces a versioning mechanism to the current RPC procotol,
so that future upgrade will be produce more user friendly with error messages.

* Address review comments

* Remove ld library path
  • Loading branch information
tqchen authored May 5, 2020
1 parent 7e88030 commit 95e06b3
Show file tree
Hide file tree
Showing 47 changed files with 4,082 additions and 1,868 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
__pycache__/
*.py[cod]
*$py.class

*.S
# C extensions
*.so
*.ll

# Distribution / packaging
.Python
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
Submodule dmlc-core updated 54 files
+38 −0 .github/workflows/githubci.yml
+0 −1 .gitignore
+0 −82 .travis.yml
+122 −111 CMakeLists.txt
+201 −13 LICENSE
+1 −1 README.md
+19 −6 appveyor.yml
+13 −0 cmake/Modules/FindASan.cmake
+13 −0 cmake/Modules/FindLSan.cmake
+13 −0 cmake/Modules/FindTSan.cmake
+13 −0 cmake/Modules/FindUBSan.cmake
+63 −0 cmake/Sanitizer.cmake
+4 −1 cmake/build_config.h.in
+1 −1 cmake/gtest_cmake.in
+1 −16 doc/Doxyfile
+23 −1 include/dmlc/base.h
+4 −1 include/dmlc/build_config_default.h
+4 −0 include/dmlc/concurrency.h
+18 −18 include/dmlc/concurrentqueue.h
+3 −2 include/dmlc/json.h
+24 −5 include/dmlc/logging.h
+1 −1 include/dmlc/omp.h
+10 −0 include/dmlc/optional.h
+106 −23 include/dmlc/parameter.h
+1 −3 include/dmlc/thread_group.h
+4 −2 include/dmlc/thread_local.h
+74 −46 include/dmlc/threadediter.h
+0 −2 make/dmlc.mk
+2 −2 scripts/lint.py
+12 −19 scripts/packages.mk
+0 −0 scripts/s390x/Dockerfile
+0 −0 scripts/s390x/build_via_cmake.sh
+1 −1 scripts/s390x/ci_build.sh
+0 −0 scripts/s390x/entrypoint.sh
+0 −32 scripts/setup_nvcc.sh
+65 −0 scripts/test_script.sh
+0 −3 scripts/travis/travis_before_cache.sh
+0 −9 scripts/travis/travis_osx_install.sh
+0 −57 scripts/travis/travis_script.sh
+0 −40 scripts/travis/travis_setup_env.sh
+0 −16 src/build_config.cc
+7 −3 src/data/csv_parser.h
+1 −1 test/logging_test.cc
+4 −0 test/unittest/CMakeLists.txt
+2 −1 test/unittest/unittest_env.cc
+30 −0 test/unittest/unittest_param.cc
+80 −56 test/unittest/unittest_parser.cc
+0 −1 test/unittest/unittest_thread_group.cc
+2 −2 test/unittest/unittest_threaditer.cc
+19 −15 test/unittest/unittest_threaditer_exc_handling.cc
+4 −0 tracker/dmlc_tracker/launcher.py
+7 −0 tracker/dmlc_tracker/ssh.py
+13 −0 tracker/dmlc_tracker/util.py
+4 −2 tracker/dmlc_tracker/yarn.py
20 changes: 10 additions & 10 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <string>

#include "../../src/support/socket.h"
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "rpc_env.h"
#include "rpc_server.h"
Expand Down Expand Up @@ -86,7 +86,7 @@ class RPCServer {
tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
custom_addr_(std::move(custom_addr))
{

}

/*!
Expand All @@ -98,7 +98,7 @@ class RPCServer {
tracker_sock_.Close();
listen_sock_.Close();
} catch(...) {

}
}

Expand Down Expand Up @@ -144,7 +144,7 @@ class RPCServer {
}

int timeout = GetTimeOutFromOpts(opts);
#if defined(__linux__) || defined(__ANDROID__)
#if defined(__linux__) || defined(__ANDROID__)
// step 3: serving
if (timeout != 0) {
const pid_t timer_pid = fork();
Expand Down Expand Up @@ -197,7 +197,7 @@ class RPCServer {
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {

}
auto dur = high_resolution_clock::now() - start_time;

Expand All @@ -217,10 +217,10 @@ class RPCServer {
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
void AcceptConnection(TrackerClient* tracker,
void AcceptConnection(TrackerClient* tracker,
support::TCPSocket* conn_sock,
support::SockAddr* addr,
std::string* opts,
support::SockAddr* addr,
std::string* opts,
int ping_period = 2) {
std::set<std::string> old_keyset;
std::string matchkey;
Expand Down Expand Up @@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) {
tvm::support::TCPSocket sock(socket);
const auto env = RPCEnv();
RPCServerLoop(int(sock.sockfd));

sock.Close();
env.CleanUp();
}
Expand All @@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc._ServerCreate")
TVM_REGISTER_GLOBAL("rpc.ServerCreate")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_tracker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <vector>
#include <string>

#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/support/socket.h"

namespace tvm {
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,54 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param nbytes The number of bytes in memory.
* \param alignment The alignment of the memory.
* \param type_hint The type of elements. Only needed by certain backends such
* as nbytes & alignment are sufficient for most backends.
* \param out_data The allocated device pointer.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx,
size_t nbytes,
size_t alignment,
DLDataType type_hint,
void** out_data);

/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr);

/*!
* \brief Copy data from one place to another.
* \param from The source array.
* \param from_offset The byte offeset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param num_bytes The size of the memory in bytes
* \param ctx_from The source context
* \param ctx_to The target context
* \param type_hint The type of elements, only neded by certain backends.
* can be useful for cross device endian converison.
* \param stream Optional stream object.
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMDeviceCopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
DLDataType type_hint,
TVMStreamHandle stream);

#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI {
* \param event_dst The destination stream to synchronize.
*/
virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
Expand All @@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI {
* as OpenGL, as nbytes is sufficient for most backends.
*/
virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
DLDataType type_hint = {});
size_t nbytes,
DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
Expand Down
53 changes: 7 additions & 46 deletions jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,53 +38,14 @@ public class GraphRuntime {
* @return Runtime graph module that can be used to execute the graph.
*/
public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) {
Module graphModule = null;
if (ctx.deviceType >= RPC.RPC_SESS_MASK) {
if (!(ctx instanceof TVMRemoteContext)) {
throw new IllegalArgumentException(
"Looks like you are using remote context with no RPCSession bind."
+ "Use session.context instead.");
}
RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession;
// check arguments
if (!"rpc".equals(libmod.typeKey())) {
throw new IllegalArgumentException("libmod.typeKey != rpc");
}
final int sessIndex = (int) ((Function) reflectionStaticCall(
RPC.class, "getApi", "_SessTableIndex"))
.pushArg(libmod).invoke().asLong();
if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
throw new IllegalArgumentException(String.format(
"libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
sessIndex, reflectionGetField(rpcSession, "tblIndex")));
}

Function rpcModuleHandle = (Function) reflectionStaticCall(
RPC.class, "getApi","_ModuleHandle");
if (rpcModuleHandle == null) {
throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
+ "Did you compile tvm_runtime with the correct version?");
}

Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
+ "Did you compile tvm_runtime with correct version?");
}

TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke();
graphModule = fcreate.call(graphJson, hmod,
ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule();
} else {
Function fcreate = Function.getFunction("tvm.graph_runtime.create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+ "Did you compile tvm_runtime with correct version?");
}
graphModule = fcreate.pushArg(graphJson)
.pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
.invoke().asModule();
Function fcreate = Function.getFunction("tvm.graph_runtime.create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+ "Did you compile tvm_runtime with correct version?");
}
Module graphModule = fcreate.pushArg(graphJson)
.pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
.invoke().asModule();

return new GraphModule(graphModule, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/rpc/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Client {
* @return The connected session.
*/
public static RPCSession connect(String url, int port, String key) {
Function doConnect = RPC.getApi("_Connect");
Function doConnect = RPC.getApi("Connect");
if (doConnect == null) {
throw new RuntimeException("Please compile with USE_RPC=1");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) {
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class RPCSession {

RPCSession(Module sess) {
session = sess;
tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong();
}

/**
Expand Down Expand Up @@ -237,7 +237,7 @@ public byte[] download(String path) {
* @return The remote module containing remote function.
*/
public Module loadModule(String path) {
return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
}


Expand Down
8 changes: 7 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMContext):
values[i].v_int64 = _ctx_to_int64(arg)
type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray):
elif isinstance(arg, (bytearray, bytes)):
# from_buffer only taeks in bytearray.
if isinstance(arg, bytes):
byte_arr = bytearray(arg)
temp_args.append(byte_arr)
arg = byte_arr

arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,13 @@ cdef inline int make_arg(object arg,
value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext
elif isinstance(arg, bytearray):
elif isinstance(arg, (bytes, bytearray)):
# from_buffer only taeks in bytearray.
if isinstance(arg, bytes):
byte_arr = bytearray(arg)
temp_args.append(byte_arr)
arg = byte_arr

arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
Expand Down
1 change: 0 additions & 1 deletion python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])

Expand Down
10 changes: 8 additions & 2 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def get_target_triple():
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None):
get_target_triple=None,
add_files=None):
"""Create a cross compiler function by specializing compile_func with options.
This function can be used to construct compile functions that
Expand All @@ -111,6 +112,10 @@ def cross_compiler(compile_func,
get_target_triple: Optional[Callable]
Function that can target triple according to dumpmachine option of compiler.
add_files: Optional[List[str]]
List of paths to additional object, source, library files
to pass as part of the compilation.
Returns
-------
fcompile : Callable[[str, str, Optional[str]], None]
Expand All @@ -133,6 +138,7 @@ def cross_compiler(compile_func,
"""
base_options = [] if options is None else options
kwargs = {}
add_files = [] if add_files is None else add_files

# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
Expand All @@ -144,7 +150,7 @@ def _fcompile(outputs, objects, options=None):
all_options = base_options
if options is not None:
all_options += options
compile_func(outputs, objects, options=all_options, **kwargs)
compile_func(outputs, objects + add_files, options=all_options, **kwargs)

if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
Expand Down
9 changes: 5 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import numpy as np
import tvm._ffi

from .._ffi.base import string_types
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
from tvm.rpc import _ffi_api as _rpc_ffi_api
from tvm.rpc import base as rpc_base
from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import TVMContext


def create(graph_json_str, libmod, ctx):
Expand Down Expand Up @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx):
device_type = cur_ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(
assert _rpc_ffi_api.SessTableIndex(
libmod) == cur_ctx._rpc_sess._tbl_index
num_rpc_ctx += 1
device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, msg):
register_error("KeyError", KeyError)


@register_error
class RPCError(RuntimeError):
"""Error thrown by the remote server handling the RPC call."""


@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/rpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
"""

from .server import Server
from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker
from .client import connect, connect_tracker
from .client import RPCSession, LocalSession, PopenSession, TrackerSession
from .minrpc import with_minrpc
Loading

0 comments on commit 95e06b3

Please sign in to comment.