Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra #5484

Merged
merged 4 commits into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
__pycache__/
*.py[cod]
*$py.class

*.S
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to ignore assembly files generated in commands.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok. I guess this means we can't check in assembly without removing it, but we can do that later too.

# C extensions
*.so
*.ll

# Distribution / packaging
.Python
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
Submodule dmlc-core updated 54 files
+38 −0 .github/workflows/githubci.yml
+0 −1 .gitignore
+0 −82 .travis.yml
+122 −111 CMakeLists.txt
+201 −13 LICENSE
+1 −1 README.md
+19 −6 appveyor.yml
+13 −0 cmake/Modules/FindASan.cmake
+13 −0 cmake/Modules/FindLSan.cmake
+13 −0 cmake/Modules/FindTSan.cmake
+13 −0 cmake/Modules/FindUBSan.cmake
+63 −0 cmake/Sanitizer.cmake
+4 −1 cmake/build_config.h.in
+1 −1 cmake/gtest_cmake.in
+1 −16 doc/Doxyfile
+23 −1 include/dmlc/base.h
+4 −1 include/dmlc/build_config_default.h
+4 −0 include/dmlc/concurrency.h
+18 −18 include/dmlc/concurrentqueue.h
+3 −2 include/dmlc/json.h
+24 −5 include/dmlc/logging.h
+1 −1 include/dmlc/omp.h
+10 −0 include/dmlc/optional.h
+106 −23 include/dmlc/parameter.h
+1 −3 include/dmlc/thread_group.h
+4 −2 include/dmlc/thread_local.h
+74 −46 include/dmlc/threadediter.h
+0 −2 make/dmlc.mk
+2 −2 scripts/lint.py
+12 −19 scripts/packages.mk
+0 −0 scripts/s390x/Dockerfile
+0 −0 scripts/s390x/build_via_cmake.sh
+1 −1 scripts/s390x/ci_build.sh
+0 −0 scripts/s390x/entrypoint.sh
+0 −32 scripts/setup_nvcc.sh
+65 −0 scripts/test_script.sh
+0 −3 scripts/travis/travis_before_cache.sh
+0 −9 scripts/travis/travis_osx_install.sh
+0 −57 scripts/travis/travis_script.sh
+0 −40 scripts/travis/travis_setup_env.sh
+0 −16 src/build_config.cc
+7 −3 src/data/csv_parser.h
+1 −1 test/logging_test.cc
+4 −0 test/unittest/CMakeLists.txt
+2 −1 test/unittest/unittest_env.cc
+30 −0 test/unittest/unittest_param.cc
+80 −56 test/unittest/unittest_parser.cc
+0 −1 test/unittest/unittest_thread_group.cc
+2 −2 test/unittest/unittest_threaditer.cc
+19 −15 test/unittest/unittest_threaditer_exc_handling.cc
+4 −0 tracker/dmlc_tracker/launcher.py
+7 −0 tracker/dmlc_tracker/ssh.py
+13 −0 tracker/dmlc_tracker/util.py
+4 −2 tracker/dmlc_tracker/yarn.py
20 changes: 10 additions & 10 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <string>

#include "../../src/support/socket.h"
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "rpc_env.h"
#include "rpc_server.h"
Expand Down Expand Up @@ -86,7 +86,7 @@ class RPCServer {
tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
custom_addr_(std::move(custom_addr))
{

}

/*!
Expand All @@ -98,7 +98,7 @@ class RPCServer {
tracker_sock_.Close();
listen_sock_.Close();
} catch(...) {

}
}

Expand Down Expand Up @@ -144,7 +144,7 @@ class RPCServer {
}

int timeout = GetTimeOutFromOpts(opts);
#if defined(__linux__) || defined(__ANDROID__)
#if defined(__linux__) || defined(__ANDROID__)
// step 3: serving
if (timeout != 0) {
const pid_t timer_pid = fork();
Expand Down Expand Up @@ -197,7 +197,7 @@ class RPCServer {
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {

}
auto dur = high_resolution_clock::now() - start_time;

Expand All @@ -217,10 +217,10 @@ class RPCServer {
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
void AcceptConnection(TrackerClient* tracker,
void AcceptConnection(TrackerClient* tracker,
tqchen marked this conversation as resolved.
Show resolved Hide resolved
support::TCPSocket* conn_sock,
support::SockAddr* addr,
std::string* opts,
support::SockAddr* addr,
std::string* opts,
int ping_period = 2) {
std::set<std::string> old_keyset;
std::string matchkey;
Expand Down Expand Up @@ -330,7 +330,7 @@ void ServerLoopFromChild(SOCKET socket) {
tvm::support::TCPSocket sock(socket);
const auto env = RPCEnv();
RPCServerLoop(int(sock.sockfd));

sock.Close();
env.CleanUp();
}
Expand All @@ -357,7 +357,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc._ServerCreate")
TVM_REGISTER_GLOBAL("rpc.ServerCreate")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_tracker_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <vector>
#include <string>

#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_endpoint.h"
#include "../../src/support/socket.h"

namespace tvm {
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,54 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param nbytes The number of bytes in memory.
* \param alignment The alignment of the memory.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \param out_data The allocated device pointer.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx,
size_t nbytes,
size_t alignment,
DLDataType type_hint,
void** out_data);

/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space to free.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr);

/*!
* \brief Copy data from one place to another.
* \param from The source array.
* \param from_offset The byte offset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param num_bytes The size of the memory in bytes.
* \param ctx_from The source context.
* \param ctx_to The target context.
* \param type_hint The type of elements, only needed by certain backends.
* Can be useful for cross device endian conversion.
* \param stream Optional stream object.
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMDeviceCopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
DLDataType type_hint,
TVMStreamHandle stream);

#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI {
* \param event_dst The destination stream to synchronize.
*/
virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
Expand All @@ -176,8 +176,8 @@ class TVM_DLL DeviceAPI {
* as OpenGL, as nbytes is sufficient for most backends.
*/
virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
DLDataType type_hint = {});
size_t nbytes,
DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
Expand Down
53 changes: 7 additions & 46 deletions jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,53 +38,14 @@ public class GraphRuntime {
* @return Runtime graph module that can be used to execute the graph.
*/
public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) {
Module graphModule = null;
if (ctx.deviceType >= RPC.RPC_SESS_MASK) {
if (!(ctx instanceof TVMRemoteContext)) {
throw new IllegalArgumentException(
"Looks like you are using remote context with no RPCSession bind."
+ "Use session.context instead.");
}
RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession;
// check arguments
if (!"rpc".equals(libmod.typeKey())) {
throw new IllegalArgumentException("libmod.typeKey != rpc");
}
final int sessIndex = (int) ((Function) reflectionStaticCall(
RPC.class, "getApi", "_SessTableIndex"))
.pushArg(libmod).invoke().asLong();
if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
throw new IllegalArgumentException(String.format(
"libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
sessIndex, reflectionGetField(rpcSession, "tblIndex")));
}

Function rpcModuleHandle = (Function) reflectionStaticCall(
RPC.class, "getApi","_ModuleHandle");
if (rpcModuleHandle == null) {
throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
+ "Did you compile tvm_runtime with the correct version?");
}

Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
+ "Did you compile tvm_runtime with correct version?");
}

TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke();
graphModule = fcreate.call(graphJson, hmod,
ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule();
} else {
Function fcreate = Function.getFunction("tvm.graph_runtime.create");
if (fcreate == null) {
throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+ "Did you compile tvm_runtime with correct version?");
}
graphModule = fcreate.pushArg(graphJson)
.pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
.invoke().asModule();
Function fcreate = Function.getFunction("tvm.graph_runtime.create");
if (fcreate == null) {
tqchen marked this conversation as resolved.
Show resolved Hide resolved
throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
+ "Did you compile tvm_runtime with correct version?");
}
Module graphModule = fcreate.pushArg(graphJson)
.pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
.invoke().asModule();

return new GraphModule(graphModule, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/rpc/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Client {
* @return The connected session.
*/
public static RPCSession connect(String url, int port, String key) {
Function doConnect = RPC.getApi("_Connect");
Function doConnect = RPC.getApi("Connect");
if (doConnect == null) {
throw new RuntimeException("Please compile with USE_RPC=1");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) {
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class RPCSession {

RPCSession(Module sess) {
session = sess;
tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong();
}

/**
Expand Down Expand Up @@ -237,7 +237,7 @@ public byte[] download(String path) {
* @return The remote module containing remote function.
*/
public Module loadModule(String path) {
return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
}


Expand Down
8 changes: 7 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMContext):
values[i].v_int64 = _ctx_to_int64(arg)
type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray):
elif isinstance(arg, (bytearray, bytes)):
# from_buffer only takes in bytearray.
if isinstance(arg, bytes):
byte_arr = bytearray(arg)
temp_args.append(byte_arr)
arg = byte_arr

arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,13 @@ cdef inline int make_arg(object arg,
value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext
elif isinstance(arg, bytearray):
elif isinstance(arg, (bytes, bytearray)):
# from_buffer only takes in bytearray.
if isinstance(arg, bytes):
byte_arr = bytearray(arg)
temp_args.append(byte_arr)
arg = byte_arr

arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
Expand Down
1 change: 0 additions & 1 deletion python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])

Expand Down
10 changes: 8 additions & 2 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def get_target_triple():
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None):
get_target_triple=None,
add_files=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a little bit of a weird place (compiler configuration time) to add extra files (I.e. invocation configuration time). why can't this be pushed to the invocation of fcompile?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed it is an API tradeoff. The main advantage of functor decorator style is that it is composable.

For example, we could use the following to attach minserver to the ndk compiler.

mod.export_library("xyz", tvm.rpc.with_minrpc(cc.ndk))

Another reason is that in the case of minrpc, we need to attach a special property(need_system_lib) to tell the export_library to check for the --system-lib option.

We might need a bit more thoughts to make the alternative API(as follows) work.

mod.export_library("xyz", cc.ndk, libs=[tvm.rpc.minrpc()])

Finally, such functor decorator is easier to work with in the AutoTVM(we just have to pass the decorated version as build_func, instead of passing an additional argument)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not very attached to the argument though. Perhaps changing to something like default_libs or add_libs would alleviate the concern?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like based on how this is used now, the problem is that we are trying to make export_library build a binary in a few small cases (but, which will become increasingly common as µTVM is more integrated). i think it's reasonable for export_library to control the compiler invocation up to the point of generating a .o. beyond that ldflags and additional statically-linked dependencies (which will have their own cflags and compiler toolchains) seem like a bit much of an API abuse. maybe this will be better addressed by the compilation and linking work i'm doing? if so, is there an intermediate hack we can do that avoids adding to this API (if you agree that's the right direction to go)?

Copy link
Member Author

@tqchen tqchen May 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is mainly a difference in API choices.

  • A0 pushing more flags to export_library
  • A1 using a functional-style programming decorator to decorate an existing compiler by adding ldflags implicitly. Pass in fcompile variants, or use fcompile->fcompile functors to decorate new compilers.

My take is that A1 is as powerful as A0, and the complexity of compiler customization goes to the construction process of fcompile. It also has the advantage of moving customization to a single place (fcompile) so that we don't have to worry about customizing the export_library step in autotvm and downstream toolchains that use export_library.

We can certainly de-couple the APIs in A1. Right now compilation and linking are coupled, and a more powerful way to construct A1 might be as follows, and we can work toward that direction.

fcompile = attach_linker_opts(get_compiler())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I think de-coupling the APIs in A1 makes sense. if we do that though, then only autotvm (and other user scripts and any µTVM veneers) need to care about linker flags and linker scripts, correct? we can remove that complexity from export_library.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your understanding is correct

"""Create a cross compiler function by specializing compile_func with options.

This function can be used to construct compile functions that
Expand All @@ -111,6 +112,10 @@ def cross_compiler(compile_func,
get_target_triple: Optional[Callable]
Function that can get the target triple according to the dumpmachine option of the compiler.

add_files: Optional[List[str]]
List of paths to additional object, source, library files
to pass as part of the compilation.

Returns
-------
fcompile : Callable[[str, str, Optional[str]], None]
Expand All @@ -133,6 +138,7 @@ def cross_compiler(compile_func,
"""
base_options = [] if options is None else options
kwargs = {}
add_files = [] if add_files is None else add_files

# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
Expand All @@ -144,7 +150,7 @@ def _fcompile(outputs, objects, options=None):
all_options = base_options
if options is not None:
all_options += options
compile_func(outputs, objects, options=all_options, **kwargs)
compile_func(outputs, objects + add_files, options=all_options, **kwargs)

if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
Expand Down
9 changes: 5 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import numpy as np
import tvm._ffi

from .._ffi.base import string_types
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
from tvm.rpc import _ffi_api as _rpc_ffi_api
from tvm.rpc import base as rpc_base
from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import TVMContext


def create(graph_json_str, libmod, ctx):
Expand Down Expand Up @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx):
device_type = cur_ctx.device_type
if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc_base._SessTableIndex(
assert _rpc_ffi_api.SessTableIndex(
libmod) == cur_ctx._rpc_sess._tbl_index
num_rpc_ctx += 1
device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, msg):
register_error("KeyError", KeyError)


@register_error
class RPCError(RuntimeError):
"""Error thrown by the remote server handling the RPC call."""


@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/rpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
"""

from .server import Server
from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker
from .client import connect, connect_tracker
from .client import RPCSession, LocalSession, PopenSession, TrackerSession
from .minrpc import with_minrpc
Loading