Skip to content

Commit

Permalink
[REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra.
Browse files Browse the repository at this point in the history
This PR refactors the RPC protocol to make it more modularized.

- RPCSession: represent a set of features that need to be implemented
- RPCEndPont: End point that forwards the RPCSession requests over a communication channel.
- RPCModule: Exposes an RPCSession as an rpc device in the TVM Runtime API.

In the new design, the local machine is presented as a special case of RPCSession.
The remote is just another client session that calls into RPCEndPoint.
The RPC communication path is as follows.

```
client -> ClientSession -> EndPoint[client@n0]
-> networking[between n0 <=> n1]
-> EndPoint[server@n1] -> LocalSession[@n1]

```

Because of the new modular design, we can now chain more sessions together.
For example, we can now run the following proxy setup (testcase in test_runtime_rpc.test_session_constructor).

```
client -> ClientSession -> Endpoint[client@n0]
-> networking[between n0 <=> n1]
-> Endpoint[server@n1] -> ClientSession -> Endpoint[client@n1]
-> networking[between n1 <=> n2]
-> Endpoint[server@n2] -> LocalSession[@n2]
```

We can also implement other types of Sessions.
For example, we could make uTVM session a special case of the RPCSession
and use the same mechanism for session management.

We also add more comments about the internal of the RPC.
The communication protocol is simplfied using a similar convention as PackedFunc.
This allows us to further reduce the amount of special remote syscalls.

Due to the major improvement and simplification, we are making a non-compatible update to the RPC protocol.
It means that the client and server needs to be upgraded to together in order for it to function correctly.

This PR also introduces a versioning mechanism to the current RPC procotol,
so that future upgrade will be produce more user friendly with error messages.
  • Loading branch information
tqchen committed Apr 30, 2020
1 parent 8d72496 commit 50921a4
Show file tree
Hide file tree
Showing 31 changed files with 2,642 additions and 1,781 deletions.
20 changes: 10 additions & 10 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <string>

#include "../../src/support/socket.h"
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "rpc_env.h"
#include "rpc_server.h"
Expand Down Expand Up @@ -86,7 +86,7 @@ class RPCServer {
tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
custom_addr_(std::move(custom_addr))
{

}

/*!
Expand All @@ -98,7 +98,7 @@ class RPCServer {
tracker_sock_.Close();
listen_sock_.Close();
} catch(...) {

}
}

Expand Down Expand Up @@ -144,7 +144,7 @@ class RPCServer {
}

int timeout = GetTimeOutFromOpts(opts);
#if defined(__linux__) || defined(__ANDROID__)
#if defined(__linux__) || defined(__ANDROID__)
// step 3: serving
if (timeout != 0) {
const pid_t timer_pid = fork();
Expand Down Expand Up @@ -197,7 +197,7 @@ class RPCServer {
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {

}
auto dur = high_resolution_clock::now() - start_time;

Expand All @@ -217,10 +217,10 @@ class RPCServer {
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
void AcceptConnection(TrackerClient* tracker,
void AcceptConnection(TrackerClient* tracker,
support::TCPSocket* conn_sock,
support::SockAddr* addr,
std::string* opts,
support::SockAddr* addr,
std::string* opts,
int ping_period = 2) {
std::set<std::string> old_keyset;
std::string matchkey;
Expand Down Expand Up @@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) {
tvm::support::TCPSocket sock(socket);
const auto env = RPCEnv();
RPCServerLoop(int(sock.sockfd));

sock.Close();
env.CleanUp();
}
Expand All @@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc._ServerCreate")
TVM_REGISTER_GLOBAL("rpc.ServerCreate")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_tracker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <vector>
#include <string>

#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_end_point.h"
#include "../../src/support/socket.h"

namespace tvm {
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI {
* \param event_dst The destination stream to synchronize.
*/
virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
Expand All @@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI {
* as OpenGL, as nbytes is sufficient for most backends.
*/
virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
DLDataType type_hint = {});
size_t nbytes,
DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,14 @@ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx
throw new IllegalArgumentException("libmod.typeKey != rpc");
}
final int sessIndex = (int) ((Function) reflectionStaticCall(
RPC.class, "getApi", "_SessTableIndex"))
RPC.class, "getApi", "SessTableIndex"))
.pushArg(libmod).invoke().asLong();
if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
throw new IllegalArgumentException(String.format(
"libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
sessIndex, reflectionGetField(rpcSession, "tblIndex")));
}

Function rpcModuleHandle = (Function) reflectionStaticCall(
RPC.class, "getApi","_ModuleHandle");
if (rpcModuleHandle == null) {
throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
+ "Did you compile tvm_runtime with the correct version?");
}

Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/rpc/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Client {
* @return The connected session.
*/
public static RPCSession connect(String url, int port, String key) {
Function doConnect = RPC.getApi("_Connect");
Function doConnect = RPC.getApi("Connect");
if (doConnect == null) {
throw new RuntimeException("Please compile with USE_RPC=1");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) {
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class RPCSession {

RPCSession(Module sess) {
session = sess;
tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong();
}

/**
Expand Down Expand Up @@ -237,7 +237,7 @@ public byte[] download(String path) {
* @return The remote module containing remote function.
*/
public Module loadModule(String path) {
return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
}


Expand Down
9 changes: 5 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import numpy as np
import tvm._ffi

from .._ffi.base import string_types
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
from tvm.rpc import _ffi_api as _rpc_ffi_api
from tvm.rpc import base as rpc_base
from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import TVMContext


def create(graph_json_str, libmod, ctx):
Expand Down Expand Up @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx):
device_type = cur_ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(
assert _rpc_ffi_api.SessTableIndex(
libmod) == cur_ctx._rpc_sess._tbl_index
num_rpc_ctx += 1
device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, msg):
register_error("KeyError", KeyError)


@register_error
class RPCError(RuntimeError):
"""Error thrown by the RPC call."""


@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/rpc/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.rpc"""
import tvm._ffi


tvm._ffi._init_api("rpc", __name__)
7 changes: 0 additions & 7 deletions python/tvm/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
"""Base definitions for RPC."""
# pylint: disable=invalid-name

from __future__ import absolute_import

import socket
import time
import json
import errno
import struct
import random
import logging
import tvm._ffi

from .._ffi.base import py_str

Expand Down Expand Up @@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
logger.warning("Cannot connect to tracker %s, retry in %g secs...",
str(addr), retry_period)
time.sleep(retry_period)


# Still use tvm.rpc for the foreign functions
tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
65 changes: 35 additions & 30 deletions python/tvm/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""RPC client tools"""
from __future__ import absolute_import

import os
import socket
import struct
import time
import tvm._ffi
from tvm.contrib import util

from tvm._ffi.base import TVMError
from tvm.runtime import ndarray as nd
from tvm.runtime import load_module as _load_module

from . import base
from . import server
from . import _ffi_api


class RPCSession(object):
Expand All @@ -38,7 +36,7 @@ class RPCSession(object):
# pylint: disable=invalid-name
def __init__(self, sess):
self._sess = sess
self._tbl_index = base._SessTableIndex(sess)
self._tbl_index = _ffi_api.SessTableIndex(sess)
self._remote_funcs = {}

def get_function(self, name):
Expand Down Expand Up @@ -145,7 +143,7 @@ def load_module(self, path):
m : Module
The remote module containing remote function.
"""
return base._LoadRemoteModule(self._sess, path)
return _ffi_api.LoadRemoteModule(self._sess, path)

def cpu(self, dev_id=0):
"""Construct CPU device."""
Expand Down Expand Up @@ -184,27 +182,8 @@ class LocalSession(RPCSession):
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.context = nd.context
self.get_function = tvm._ffi.get_global_func
self._temp = util.tempdir()

def upload(self, data, target=None):
if isinstance(data, bytearray):
if not target:
raise ValueError("target must present when file is a bytearray")
blob = data
else:
blob = bytearray(open(data, "rb").read())
if not target:
target = os.path.basename(data)
with open(self._temp.relpath(target), "wb") as f:
f.write(blob)

def download(self, path):
return bytearray(open(self._temp.relpath(path), "rb").read())

def load_module(self, path):
return _load_module(self._temp.relpath(path))
self._temp = server._server_env([])
RPCSession.__init__(self, _ffi_api.LocalSession())


class TrackerSession(object):
Expand Down Expand Up @@ -378,7 +357,7 @@ def request_and_run(self,
key, max_retry, str(last_err)))


def connect(url, port, key="", session_timeout=0):
def connect(url, port, key="", session_timeout=0, session_constructor=None):
"""Connect to RPC Server
Parameters
Expand All @@ -397,15 +376,41 @@ def connect(url, port, key="", session_timeout=0):
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
session_constructor: List
List of additional arguments to passed as the remote session constructor.
Returns
-------
sess : RPCSession
The connected session.
Examples
--------
Normal usage
.. code-block:: python
client = rpc.connect(server_url, server_port, server_key)
Session_constructor can be used to customize the session in the remote
The following code connects to a remote internal server via a proxy
by constructing another RPCClientSession on the proxy machine and use that
as the serving session of the proxy endpoint.
.. code-block:: python
client_via_proxy = rpc.connect(
proxy_server_url, proxy_server_port, proxy_server_key,
session_constructor=[
"rpc.Connect", internal_url, internal_port, internal_key])
"""
try:
if session_timeout:
key += " -timeout=%s" % str(session_timeout)
sess = base._Connect(url, port, key)
session_constructor = session_constructor if session_constructor else []
if not isinstance(session_constructor, (list, tuple)):
raise TypeError("Expect the session constructor to be a list or tuple")
sess = _ffi_api.Connect(url, port, key, *session_constructor)
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
raise ImportError(
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)

from . import base
from . import _ffi_api
from .base import TrackerCode
from .server import _server_env
from .._ffi.base import py_str
Expand Down Expand Up @@ -549,7 +549,7 @@ def _fsend(data):
data = bytes(data)
conn.write_message(data, binary=True)
return len(data)
on_message = base._CreateEventDrivenServer(
on_message = _ffi_api.CreateEventDrivenServer(
_fsend, "WebSocketProxyServer", "%toinit")
return on_message

Expand Down
Loading

0 comments on commit 50921a4

Please sign in to comment.