Skip to content

Commit

Permalink
support custom IP address from rpc server to tracker (PUT) (apache#1243)
Browse files Browse the repository at this point in the history
  • Loading branch information
eqy authored and tqchen committed Jun 8, 2018
1 parent cb4e638 commit 2566a72
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
22 changes: 13 additions & 9 deletions python/tvm/contrib/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret


def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Lisenting loop of the server master."""
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
Expand All @@ -93,7 +92,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
if tracker_conn:
matchkey = base.random_key(rpc_key + ":")
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
[TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
else:
matchkey = rpc_key
Expand All @@ -109,17 +108,18 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
pending_keys = base.recvjson(tracker_conn)
old_keyset.add(matchkey)
# if match key not in pending key set
# it means the key is aqquired by a client but not used.
# it means the key is acquired by a client but not used.
if matchkey not in pending_keys:
unmatch_period_count += 1
else:
unmatch_period_count = 0
# regenerate match key if key is aqquired but not used for a while
# regenerate match key if key is acquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
[TrackerCode.PUT, rpc_key, (port, matchkey),
custom_addr])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
unmatch_period_count = 0
continue
Expand Down Expand Up @@ -278,7 +278,8 @@ def __init__(self,
use_popen=False,
tracker_addr=None,
key="",
load_library=None):
load_library=None,
custom_addr=None):
try:
if base._ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
Expand All @@ -287,6 +288,7 @@ def __init__(self,
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr

if use_popen:
cmd = ["python",
Expand All @@ -298,7 +300,9 @@ def __init__(self,
cmd += ["--tracker=%s:%d" % tracker_addr,
"--key=%s" % key]
if load_library:
cmd += ["--load-libary", load_library]
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
Expand All @@ -324,7 +328,7 @@ def __init__(self,
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library))
self.sock, self.port, key, tracker_addr, load_library, self.custom_addr))
self.proc.deamon = True
self.proc.start()
else:
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/contrib/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ def call_handler(self, args):
key = args[1]
port, matchkey = args[2]
self.pending_matchkeys.add(matchkey)
self._tracker.put(key, (self, self._addr[0], port, matchkey))
# got custom address (from rpc server)
if args[3] is not None:
self._tracker.put(key, (self, args[3], port, matchkey))
else:
self._tracker.put(key, (self, self._addr[0], port, matchkey))
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST:
key = args[1]
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import logging
from ..contrib import rpc


def main(args):
"""Main funciton"""
"""Main function"""

if args.tracker:
url, port = args.tracker.split(":")
Expand All @@ -27,7 +26,8 @@ def main(args):
args.port_end,
key=args.key,
tracker_addr=tracker_addr,
load_library=args.load_library)
load_library=args.load_library,
custom_addr=args.custom_addr)
server.proc.join()


Expand All @@ -49,6 +49,9 @@ def main(args):
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.")
parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker")

parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
Expand Down

0 comments on commit 2566a72

Please sign in to comment.