Skip to content

Commit

Permalink
[RPC] graduate tvm.contrib.rpc -> tvm.rpc (#1410)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jul 9, 2018
1 parent 561e548 commit afd2b9b
Show file tree
Hide file tree
Showing 47 changed files with 117 additions and 97 deletions.
3 changes: 2 additions & 1 deletion apps/android_rpc/tests/android_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import tvm
import os
from tvm.contrib import rpc, util, ndk
from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np

# Set to be address of tvm proxy.
Expand Down
3 changes: 2 additions & 1 deletion apps/ios_rpc/tests/ios_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import tvm
import os
from tvm.contrib import rpc, util, xcode
from tvm import rpc
from tvm.contrib import util, xcode
import numpy as np

# Set to be address of tvm proxy.
Expand Down
4 changes: 2 additions & 2 deletions apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ void LaunchSyncServer() {
->ServerLoop();
}

TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});

TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string name = args[0];
std::string fmt = GetFileFormat(name, "");
Expand Down
18 changes: 9 additions & 9 deletions docs/api/python/rpc.rst
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
tvm.contrib.rpc
---------------
.. automodule:: tvm.contrib.rpc
tvm.rpc
-------
.. automodule:: tvm.rpc

.. autofunction:: tvm.contrib.rpc.connect
.. autofunction:: tvm.contrib.rpc.connect_tracker
.. autofunction:: tvm.rpc.connect
.. autofunction:: tvm.rpc.connect_tracker

.. autoclass:: tvm.contrib.rpc.TrackerSession
.. autoclass:: tvm.rpc.TrackerSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.RPCSession
.. autoclass:: tvm.rpc.RPCSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.LocalSession
.. autoclass:: tvm.rpc.LocalSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.Server
.. autoclass:: tvm.rpc.Server
:members:
:inherited-members:
6 changes: 3 additions & 3 deletions docs/install/docker.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ We can then use the following command to launch a `tvmai/demo-cpu` image.

.. code:: bash
/path/to/tvm/docker/bash.sh tvmai/demo_cpu
/path/to/tvm/docker/bash.sh tvmai/demo-cpu
.. note::
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_
You can also change `demo-cpu` to `demo-gpu` to get a CUDA enabled image.
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_


This auxiliary script does the following things:
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ private static File serverEnv() throws IOException {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}

Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
Function.register("tvm.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString();
}
}, true);

Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
Function.register("tvm.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String filename = args[0].asString();
String path = tempDir + File.separator + filename;
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ protected Map<String, Function> initialValue() {
static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("contrib.rpc." + name);
func = Function.getFunction("rpc." + name);
if (func == null) {
return null;
}
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public void upload(byte[] data, String target) {
final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
remoteFunc = getFunction("tvm.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc);
}
remoteFunc.pushArg(target).pushArg(data).invoke();
Expand Down Expand Up @@ -205,7 +205,7 @@ public byte[] download(String path) {
final String name = "download";
Function func = remoteFuncs.get(name);
if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download");
func = getFunction("tvm.rpc.server.download");
remoteFuncs.put(name, func);
}
return func.pushArg(path).invoke().asBytes();
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/test/scripts/test_rpc_proxy_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from tvm.contrib.rpc import proxy
from tvm.rpc import proxy

def start_proxy_server(port, timeout):
prox = proxy.Proxy("localhost", port=port, port_end=port+1)
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/compiler/test_param_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import nnvm.compiler
import tvm
from tvm.contrib import rpc, util, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime


def test_save_load():
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/compiler/test_rpc_exec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tvm
from tvm.contrib import util, rpc, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime
import nnvm.symbol as sym
import nnvm.compiler
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .rpc import base as rpc_base
from ..rpc import base as rpc_base
from .. import ndarray as nd


Expand Down
12 changes: 6 additions & 6 deletions python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
"""measure bandwidth and compute peak"""

import logging

import tvm
from tvm.contrib import rpc, util
from . import util
from .. import rpc

def _convert_to_remote(func, remote):
""" convert module function to remote rpc function"""
Expand Down Expand Up @@ -47,7 +47,7 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
host compilation target
ctx: TVMcontext
the context of array
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
n_times: int
number of runs for taking mean
Expand Down Expand Up @@ -107,7 +107,7 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
Expand Down Expand Up @@ -165,7 +165,7 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
if it is not None, use remote rpc session
ctx: TVMcontext
the context of array
Expand Down Expand Up @@ -250,7 +250,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/contrib/rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Deprecation RPC module"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
import warnings
from ..rpc import Server, RPCSession, LocalSession, TrackerSession, connect, connect_tracker

warnings.warn(
"Please use tvm.rpc instead of tvm.conrtib.rpc. tvm.contrib.rpc is going to be removed in 0.5",
DeprecationWarning)
2 changes: 1 addition & 1 deletion python/tvm/exec/query_rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import argparse
import os
from ..contrib import rpc
from .. import rpc

def main():
"""Main funciton"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import multiprocessing
import sys
import os
from ..contrib.rpc.proxy import Proxy
from ..rpc.proxy import Proxy


def find_example_resource():
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import multiprocessing
import sys
import logging
from ..contrib import rpc
from .. import rpc

def main(args):
"""Main function"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
import multiprocessing
import sys
from ..contrib.rpc.tracker import Tracker
from ..rpc.tracker import Tracker


def main(args):
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions python/tvm/contrib/rpc/base.py → python/tvm/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import random
import logging

from ..._ffi.function import _init_api
from ..._ffi.base import py_str
from .._ffi.function import _init_api
from .._ffi.base import py_str

# Magic header for RPC data plane
RPC_MAGIC = 0xff271
Expand Down Expand Up @@ -158,5 +158,5 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
time.sleep(retry_period)


# Still use tvm.contrib.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base")
# Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base")
14 changes: 7 additions & 7 deletions python/tvm/contrib/rpc/client.py → python/tvm/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import time

from . import base
from .. import util
from ..._ffi.base import TVMError
from ..._ffi import function as function
from ..._ffi import ndarray as nd
from ...module import load as _load_module
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import function as function
from .._ffi import ndarray as nd
from ..module import load as _load_module


class RPCSession(object):
Expand Down Expand Up @@ -82,7 +82,7 @@ def upload(self, data, target=None):

if "upload" not in self._remote_funcs:
self._remote_funcs["upload"] = self.get_function(
"tvm.contrib.rpc.server.upload")
"tvm.rpc.server.upload")
self._remote_funcs["upload"](target, blob)

def download(self, path):
Expand All @@ -100,7 +100,7 @@ def download(self, path):
"""
if "download" not in self._remote_funcs:
self._remote_funcs["download"] = self.get_function(
"tvm.contrib.rpc.server.download")
"tvm.rpc.server.download")
return self._remote_funcs["download"](path)

def load_module(self, path):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from . import base
from .base import TrackerCode
from .server import _server_env
from ..._ffi.base import py_str
from .._ffi.base import py_str


class ForwardHandler(object):
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/contrib/rpc/server.py → python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import time
import sys

from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ..._ffi.libinfo import find_lib_path
from ...module import load as _load_module
from .. import util
from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from . import base
from . base import TrackerCode

Expand All @@ -36,11 +36,11 @@ def _server_env(load_library, logger):
logger = logging.getLogger()

# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath")
@register_func("tvm.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)

@register_func("tvm.contrib.rpc.server.load_module", override=True)
@register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)

from ..._ffi.base import py_str
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void Module::Import(Module other) {
if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_event_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
});
}

TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
Expand Down
Loading

0 comments on commit afd2b9b

Please sign in to comment.