diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py index e3c731a20eb9..bef85bc5711c 100644 --- a/python/tvm/contrib/rpc/server.py +++ b/python/tvm/contrib/rpc/server.py @@ -71,8 +71,7 @@ def _parse_server_opt(opts): ret["timeout"] = float(kv[9:]) return ret - -def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): +def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): """Lisenting loop of the server master.""" def _accept_conn(listen_sock, tracker_conn, ping_period=2): """Accept connection from the other places. @@ -93,7 +92,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): if tracker_conn: matchkey = base.random_key(rpc_key + ":") base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey)]) + [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]) assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS else: matchkey = rpc_key @@ -109,17 +108,18 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): pending_keys = base.recvjson(tracker_conn) old_keyset.add(matchkey) # if match key not in pending key set - # it means the key is aqquired by a client but not used. + # it means the key is acquired by a client but not used. if matchkey not in pending_keys: unmatch_period_count += 1 else: unmatch_period_count = 0 - # regenerate match key if key is aqquired but not used for a while + # regenerate match key if key is acquired but not used for a while if unmatch_period_count * ping_period > unmatch_timeout + ping_period: logging.info("RPCServer: no incoming connections, regenerate key ...") matchkey = base.random_key(rpc_key + ":", old_keyset) base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey)]) + [TrackerCode.PUT, rpc_key, (port, matchkey), + custom_addr]) assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS unmatch_period_count = 0 continue @@ -278,7 +278,8 @@ def __init__(self, use_popen=False, tracker_addr=None, key="", - load_library=None): + load_library=None, + custom_addr=None): try: if base._ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") @@ -287,6 +288,7 @@ def __init__(self, self.host = host self.port = port self.libs = [] + self.custom_addr = custom_addr if use_popen: cmd = ["python", @@ -298,7 +300,9 @@ def __init__(self, cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key] if load_library: - cmd += ["--load-libary", load_library] + cmd += ["--load-library", load_library] + if custom_addr: + cmd += ["--custom-addr", custom_addr] self.proc = multiprocessing.Process( target=subprocess.check_call, args=(cmd,)) self.proc.deamon = True @@ -324,7 +328,7 @@ def __init__(self, self.sock = sock self.proc = multiprocessing.Process( target=_listen_loop, args=( - self.sock, self.port, key, tracker_addr, load_library)) + self.sock, self.port, key, tracker_addr, load_library, self.custom_addr)) self.proc.deamon = True self.proc.start() else: diff --git a/python/tvm/contrib/rpc/tracker.py b/python/tvm/contrib/rpc/tracker.py index 165ff5b80789..812b3a9770ab 100644 --- a/python/tvm/contrib/rpc/tracker.py +++ b/python/tvm/contrib/rpc/tracker.py @@ -194,7 +194,11 @@ def call_handler(self, args): key = args[1] port, matchkey = args[2] self.pending_matchkeys.add(matchkey) - self._tracker.put(key, (self, self._addr[0], port, matchkey)) + # got custom address (from rpc server) + if args[3] is not None: + self._tracker.put(key, (self, args[3], port, matchkey)) + else: + self._tracker.put(key, (self, self._addr[0], port, matchkey)) self.ret_value(TrackerCode.SUCCESS) elif code == TrackerCode.REQUEST: key = args[1] diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 773d14e2925a..d874ed63b673 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -8,9 +8,8 @@ import logging from ..contrib import rpc - def main(args): - """Main funciton""" + """Main function""" if args.tracker: url, port = args.tracker.split(":") @@ -27,7 +26,8 @@ def main(args): args.port_end, key=args.key, tracker_addr=tracker_addr, - load_library=args.load_library) + load_library=args.load_library, + custom_addr=args.custom_addr) server.proc.join() @@ -49,6 +49,9 @@ def main(args): help="Use spawn mode to avoid fork. This option \ is able to avoid potential fork problems with Metal, OpenCL \ and ROCM compilers.") + parser.add_argument('--custom-addr', type=str, + help="Custom IP Address to Report to RPC Tracker") + parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO)