diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 1a1bdfcb09b8..eaa95e8cb79c 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -1,11 +1,15 @@ +# pylint: disable=redefined-outer-name, invalid-name """RPC web proxy, allows redirect to websocket based RPC servers(browsers)""" from __future__ import absolute_import import logging import argparse +import multiprocessing +import sys import os from ..contrib.rpc.proxy import Proxy + def find_example_resource(): """Find resource examples.""" curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -22,22 +26,8 @@ def find_example_resource(): return index_page, js_files -def main(): +def main(args): """Main funciton""" - parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", - help='the hostname of the server') - parser.add_argument('--port', type=int, default=9090, - help='The port of the PRC') - parser.add_argument('--web-port', type=int, default=8888, - help='The port of the http/websocket server') - parser.add_argument('--example-rpc', type=bool, default=False, - help='Whether to switch on example rpc mode') - parser.add_argument('--tracker', type=str, default="", - help="Report to RPC tracker") - args = parser.parse_args() - logging.basicConfig(level=logging.INFO) - if args.tracker: url, port = args.tracker.split(":") port = int(port) @@ -60,5 +50,33 @@ def main(): tracker_addr=tracker_addr) prox.proc.join() + if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument('--host', type=str, default="0.0.0.0", + help='the hostname of the server') + parser.add_argument('--port', type=int, default=9090, + help='The port of the PRC') + parser.add_argument('--web-port', type=int, default=8888, + help='The port of the http/websocket server') + parser.add_argument('--example-rpc', type=bool, default=False, + help='Whether to switch on example rpc mode') + parser.add_argument('--tracker', type=str, default="", + help="Report to RPC tracker") + parser.add_argument('--no-fork', dest='fork', action='store_false', + help="Use spawn mode to avoid fork. This option \ + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.") + parser.set_defaults(fork=True) + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + if args.fork is False: + if sys.version_info[0] < 3: + raise RuntimeError( + "Python3 is required for spawn mode." + ) + multiprocessing.set_start_method('spawn') + else: + logging.info("If you are running ROCM/Metal, \ + fork with cause compiler internal error. Try to launch with arg ```--no-fork```") + main(args) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 46e7877a2803..773d14e2925a 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -1,26 +1,16 @@ +# pylint: disable=redefined-outer-name, invalid-name """Start an RPC server""" from __future__ import absolute_import import argparse +import multiprocessing +import sys import logging from ..contrib import rpc -def main(): + +def main(args): """Main funciton""" - parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", - help='the hostname of the server') - parser.add_argument('--port', type=int, default=9090, - help='The port of the PRC') - parser.add_argument('--port-end', type=int, default=9199, - help='The end search port of the PRC') - parser.add_argument('--key', type=str, default="", - help="RPC key used to identify the connection type.") - parser.add_argument('--load-library', type=str, default="", - help="Additional library to load") - parser.add_argument('--tracker', type=str, default="", - help="Report to RPC tracker") - args = parser.parse_args() if args.tracker: url, port = args.tracker.split(":") @@ -32,7 +22,6 @@ def main(): else: tracker_addr = None - logging.basicConfig(level=logging.INFO) server = rpc.Server(args.host, args.port, args.port_end, @@ -41,5 +30,35 @@ def main(): load_library=args.load_library) server.proc.join() + if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument('--host', type=str, default="0.0.0.0", + help='the hostname of the server') + parser.add_argument('--port', type=int, default=9090, + help='The port of the PRC') + parser.add_argument('--port-end', type=int, default=9199, + help='The end search port of the PRC') + parser.add_argument('--key', type=str, default="", + help="RPC key used to identify the connection type.") + parser.add_argument('--load-library', type=str, default="", + help="Additional library to load") + parser.add_argument('--tracker', type=str, default="", + help="Report to RPC tracker") + parser.add_argument('--no-fork', dest='fork', action='store_false', + help="Use spawn mode to avoid fork. This option \ + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.") + parser.set_defaults(fork=True) + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + if args.fork is False: + if sys.version_info[0] < 3: + raise RuntimeError( + "Python3 is required for spawn mode." + ) + multiprocessing.set_start_method('spawn') + else: + logging.info("If you are running ROCM/Metal, \ + fork with cause compiler internal error. Try to launch with arg ```--no-fork```") + main(args) diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index a5b063f37908..3b76f57eb689 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -1,21 +1,40 @@ +# pylint: disable=redefined-outer-name, invalid-name """Tool to start RPC tracker""" from __future__ import absolute_import import logging import argparse +import multiprocessing +import sys from ..contrib.rpc.tracker import Tracker -def main(): + +def main(args): """Main funciton""" + tracker = Tracker(args.host, port=args.port) + tracker.proc.join() + + +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--host', type=str, default="0.0.0.0", help='the hostname of the tracker') parser.add_argument('--port', type=int, default=9190, help='The port of the PRC') + parser.add_argument('--no-fork', dest='fork', action='store_false', + help="Use spawn mode to avoid fork. This option \ + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.") + parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - tracker = Tracker(args.host, port=args.port) - tracker.proc.join() - -if __name__ == "__main__": - main() + if args.fork is False: + if sys.version_info[0] < 3: + raise RuntimeError( + "Python3 is required for spawn mode." + ) + multiprocessing.set_start_method('spawn') + else: + logging.info("If you are running ROCM/Metal, \ + fork with cause compiler internal error. Try to launch with arg ```--no-fork```") + main(args)