diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 58728993bc3f..34095acb7ad1 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1004,6 +1004,7 @@ cd_unittest_ubuntu() {
         # Adding these here as CI doesn't test all CUDA environments
         $python_cmd example/image-classification/test_score.py
         integrationtest_ubuntu_gpu_dist_kvstore
+        integrationtest_ubuntu_gpu_byteps
     fi

     if [[ ${mxnet_variant} = *mkl ]]; then
@@ -1351,6 +1352,24 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
     popd
 }

+integrationtest_ubuntu_gpu_byteps() {
+    set -ex
+    pushd .
+    export PYTHONPATH=$PWD/python/
+    export BYTEPS_WITHOUT_PYTORCH=1
+    export BYTEPS_WITHOUT_TENSORFLOW=1
+    git clone -b v0.2 https://github.com/bytedance/byteps/ --recursive
+    cd byteps && python3 setup.py install --user && cd -
+
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export MXNET_SUBGRAPH_VERBOSE=0
+    export DMLC_LOG_STACK_TRACE_DEPTH=10
+    cd tests/nightly/
+    python3 ../../tools/launch.py -n 1 -s 1 --byteps --env NVIDIA_VISIBLE_DEVICES:0,1 python3 dist_device_sync_kvstore_byteps.py
+    popd
+}
+
+
 test_ubuntu_cpu_python3() {
     set -ex
     pushd .
diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy
index aea1c51d0095..a3d9003ec86f 100644
--- a/ci/jenkins/Jenkins_steps.groovy
+++ b/ci/jenkins/Jenkins_steps.groovy
@@ -1277,6 +1277,20 @@ def test_unix_distributed_kvstore_cpu() {
     }]
 }

+def test_unix_byteps_gpu() {
+    return ['byteps tests GPU': {
+      node(NODE_LINUX_GPU) {
+        ws('workspace/it-byteps') {
+          timeout(time: max_time, unit: 'MINUTES') {
+            utils.unpack_and_init('gpu', mx_lib)
+            utils.docker_run('ubuntu_gpu_cu101', 'integrationtest_ubuntu_gpu_byteps', true)
+            utils.publish_test_coverage()
+          }
+        }
+      }
+    }]
+}
+
 def test_unix_distributed_kvstore_gpu() {
     return ['dist-kvstore tests GPU': {
       node(NODE_LINUX_GPU) {
diff --git a/ci/jenkins/Jenkinsfile_edge b/ci/jenkins/Jenkinsfile_edge
index 9d8e01399d7c..2b4ee529e987 100644
--- a/ci/jenkins/Jenkinsfile_edge
+++ b/ci/jenkins/Jenkinsfile_edge
@@ -44,7 +44,7 @@ core_logic: {

   utils.parallel_stage('Tests', [
     custom_steps.test_qemu_armv7_cpu()
-  ]) 
+  ])
 }
 ,
 failure_handler: {
diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu
index 66d3c1391944..df374fbb986a 100644
--- a/ci/jenkins/Jenkinsfile_unix_gpu
+++ b/ci/jenkins/Jenkinsfile_unix_gpu
@@ -60,6 +60,7 @@ core_logic: {
     custom_steps.test_unix_cpp_package_gpu(),
     custom_steps.test_unix_scala_gpu(),
     custom_steps.test_unix_distributed_kvstore_gpu(),
+    custom_steps.test_unix_byteps_gpu(),
     custom_steps.test_static_python_gpu(),
     custom_steps.test_static_python_gpu_cmake(),
     custom_steps.test_unix_python3_gpu_no_tvm_op(),
diff --git a/python/mxnet/kvstore/__init__.py b/python/mxnet/kvstore/__init__.py
index ccb58a1c6229..bc099179a88a 100644
--- a/python/mxnet/kvstore/__init__.py
+++ b/python/mxnet/kvstore/__init__.py
@@ -22,3 +22,4 @@
 from .kvstore import *
 from .base import *
 from .kvstore_server import *
+from .byteps import *
diff --git a/python/mxnet/kvstore/base.py b/python/mxnet/kvstore/base.py
index 63d46823a1b7..39e84f52b2e3 100644
--- a/python/mxnet/kvstore/base.py
+++ b/python/mxnet/kvstore/base.py
@@ -313,6 +313,8 @@ def pushpull(self, key, value, out=None, priority=0):
     def is_capable(capability):
         """Queries if the KVStore type supports certain capability, such as optimizer algorithm,
         gradient compression, sparsity, etc.
+        If the kvstore does not store weights on the server side, then no optimizer is
+        supported and this function will return False.
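+        For example, a check such as ``mx.kv.create('byteps').is_capable('optimizer')``
+        would return False under this rule, since the BytePS backend keeps no
+        weights (and hence no optimizer state) on the server.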

         Parameters
         ----------
@@ -427,10 +429,15 @@ def create(name='local'):
     No two updates happen on the same weight at the same time. However, the order is not
     guaranteed.

+    ``byteps``: Use BytePS as the broadcast/pushpull backend.
+    This kind of kvstore does not store weights, so there is no optimizer in the kvstore server.
+    BytePS does not support pure CPU training, so be sure to enable GPU training when using this kvstore.
+
     Parameters
     ----------
-    name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async', 'horovod'}
+    name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async', 'horovod', 'byteps'}
         The type of KVStore.
+
     Returns
     -------
     kv : KVStoreBase
diff --git a/python/mxnet/kvstore/byteps.py b/python/mxnet/kvstore/byteps.py
new file mode 100644
index 000000000000..d6493fd611a3
--- /dev/null
+++ b/python/mxnet/kvstore/byteps.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""BytePS backend for MXNet KVStore"""
+from __future__ import absolute_import
+
+from ..ndarray import NDArray
+from .base import KVStoreBase
+
+__all__ = ['BytePS']
+
+
+@KVStoreBase.register
+class BytePS(KVStoreBase):
+    """BytePS backend for the MXNet KVStore interface."""
+
+    def __init__(self):
+        """Initializes a new KVStore."""
+        try:
+            import byteps.mxnet as bps
+            self.handle = bps
+        except ImportError as err:
+            print('Did not find BytePS library. Please install BytePS first.')
+            raise err
+        self.handle.init()
+
+    def broadcast(self, key, value, out, priority=0):
+        """ Broadcast the value NDArray at rank 0 to all ranks' out. If out is None,
+        the result is stored in `value`.
+
+        Parameters
+        ----------
+        key : str, or int
+            The key.
+        value : NDArray, or list of NDArray
+            Values corresponding to the key.
+        out : NDArray, or list of NDArray
+            Values corresponding to the key.
+
+        Examples
+        --------
+        >>> # broadcast a single key-value pair
+        >>> shape = (2,3)
+        >>> kv = mx.kv.create('byteps')
+        >>> a = mx.nd.zeros(shape)
+        >>> kv.broadcast('3', mx.nd.ones(shape)*2, out=a)
+        >>> print(a.asnumpy())
+        [[ 2.  2.  2.]
+         [ 2.  2.  2.]]
+        """
+        # do not accept list or tuple for key/value
+        assert isinstance(key, (str, int))
+
+        # unpack the list if it contains just one NDArray
+        value = value[0] if isinstance(
+            value, list) and len(value) == 1 else value
+        assert isinstance(
+            value, NDArray), "The type of value can only be NDArray or a list of NDArray with a single element."
+        assert value.context.device_type == 'gpu', "BytePS KVStore only supports GPU context for broadcast values."
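+        # NOTE: the GPU-only restriction above mirrors BytePS itself, which does
+        # not support pure CPU training (see the ``byteps`` entry in the
+        # ``create`` docstring in base.py).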
+
+        # optimization when out = value or out = [value]
+        if isinstance(out, (list, tuple)) and len(out) == 1:
+            inplace = value is out[0]
+        else:
+            inplace = value is out
+
+        if inplace:
+            broadcast_value = value
+        else:
+            broadcast_value = value.copy()
+        # for non-root ranks, multiply the value by 0, so the result of pushpull
+        # (a sum across ranks) equals the root rank's value, i.e. a broadcast.
+        root_rank = 0
+        if self.rank != root_rank:
+            broadcast_value.__imul__(0)
+        self.handle.byteps_push_pull(broadcast_value, version=0, priority=priority,
+                                     name=str(key), is_average=False)
+        # Make sure tensors pushed to the MXNet engine get processed, so that all
+        # workers are synced before starting training.
+        broadcast_value.wait_to_read()
+
+        out = out if isinstance(out, list) else [out]
+        for o in out:
+            broadcast_value.copyto(o)
+
+    def pushpull(self, key, value, out=None, priority=0):
+        """ Performs push and pull of a single value from the store.
+        This function is a coalesced form of the push and pull operations.
+        `value` is pushed to the kvstore server for the specified key and the aggregated
+        value is pulled from the server to `out`. If `out` is not specified, the pulled
+        value is written to `value`.
+
+        Parameters
+        ----------
+        key : str, or int
+            The key.
+        value : NDArray, or list of NDArray
+            Values corresponding to the key.
+        out : NDArray, or list of NDArray
+            Values corresponding to the key.
+        priority : int, optional
+            The priority of the operation.
+            Higher priority operations are likely to be executed before other actions.
+
+        Examples
+        --------
+        >>> # pushpull a single key-value pair
+        >>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a)
+        >>> print(a.asnumpy())
+        [[ 8.  8.  8.]
+         [ 8.  8.  8.]]
+        """
+        # the most common call pattern operates on one NDArray as `value`, with
+        # `out` set to None, for an inplace pushpull.
+
+        assert isinstance(key, (str, int))
+
+        # unpack the list if it contains just one NDArray
+        value = value[0] if isinstance(
+            value, list) and len(value) == 1 else value
+        assert isinstance(
+            value, NDArray), "The type of value can only be NDArray or a list of NDArray with a single element."
+        assert value.context.device_type == 'gpu', "BytePS KVStore only supports GPU context for pushpull values."
+
+        # optimization when out = value or out = [value]
+        if isinstance(out, (list, tuple)) and len(out) == 1:
+            inplace = value is out[0]
+        else:
+            inplace = value is out
+
+        if inplace:
+            pushpull_value = value
+        else:
+            pushpull_value = value.copy()
+
+        self.handle.byteps_push_pull(pushpull_value, version=0, priority=priority,
+                                     name=str(key), is_average=False)
+
+        if out is not None:
+            out = out if isinstance(out, list) else [out]
+            for o in out:
+                pushpull_value.copyto(o)
+
+    @staticmethod
+    def is_capable(capability):
+        """Queries if the KVStore type supports certain capability, such as optimizer algorithm,
+        gradient compression, sparsity, etc.
+        As the BytePS server does not store weights, this function returns False
+        for all capabilities.
+
+        Parameters
+        ----------
+        capability : str
+            The capability to query.
+
+        Returns
+        -------
+        result : bool
+            Whether the capability is supported or not.
+        """
+        return False
+
+    @property
+    def type(self):
+        """ Returns the type of this kvstore.
+
+        Returns
+        -------
+        type : str
+            The string type.
+        """
+        return 'byteps'
+
+    @property
+    def local_rank(self):
+        """ Returns the local rank of this worker on the node.
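+        For example, with 2 machines each running 4 worker processes (8 workers
+        in total), ``local_rank`` ranges over [0, 4) on every machine, while
+        ``rank`` ranges over [0, 8).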
+
+        Returns
+        -------
+        rank : int
+            The local rank of this worker, which is in range [0, num_workers_on_current_node())
+        """
+        return self.handle.local_rank()
+
+    @property
+    def rank(self):
+        """ Returns the rank of this worker node.
+
+        Returns
+        -------
+        rank : int
+            The rank of this worker, which is in range [0, num_workers())
+        """
+        return self.handle.rank()
+
+    @property
+    def num_workers(self):
+        """Returns the number of worker nodes.
+
+        Returns
+        -------
+        size : int
+            The number of worker nodes.
+        """
+        return self.handle.size()
+
+    def set_optimizer(self, optimizer):
+        """
+        Not implemented yet.
+
+        Parameters
+        ----------
+        optimizer : Optimizer
+            The new optimizer for the store.
+        """
+        raise NotImplementedError()
+
+    def save_optimizer_states(self, fname, dump_optimizer=False):
+        """
+        Not implemented yet.
+
+        Parameters
+        ----------
+        fname : str
+            Path to the output states file.
+        dump_optimizer : bool, default False
+            Whether to also save the optimizer itself. This would also save optimizer
+            information such as learning rate and weight decay schedules.
+        """
+        raise NotImplementedError()
+
+    def load_optimizer_states(self, fname):
+        """
+        Not implemented yet.
+
+        Parameters
+        ----------
+        fname : str
+            Path to the input states file.
+        """
+        raise NotImplementedError()
diff --git a/tests/nightly/dist_device_sync_kvstore_byteps.py b/tests/nightly/dist_device_sync_kvstore_byteps.py
new file mode 100644
index 000000000000..676be92611e1
--- /dev/null
+++ b/tests/nightly/dist_device_sync_kvstore_byteps.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
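+
+# NOTE: this test is meant to be driven by tools/launch.py with --byteps, which
+# starts the BytePS scheduler, server and worker processes; CI runs it as:
+#   python3 ../../tools/launch.py -n 1 -s 1 --byteps \
+#       --env NVIDIA_VISIBLE_DEVICES:0,1 python3 dist_device_sync_kvstore_byteps.py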
+
+import sys
+sys.path.insert(0, "../../python/")
+import mxnet as mx
+import numpy as np
+import numpy.random as rnd
+import time
+import argparse
+from mxnet.log import get_logger
+import logging
+from mxnet.kvstore import BytePS
+logger = get_logger("Byteps-Backend-Test", level=logging.DEBUG)
+
+# parser
+parser = argparse.ArgumentParser(description='kvstore test')
+parser.add_argument('--name', type=str, default='byteps')
+args = parser.parse_args()
+
+def check_diff_to_scalar(A, x, rank=None):
+    """ assert A == x"""
+    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
+
+# setup
+keys = ['3', '5', '7']
+init_test_keys = [str(i) for i in range(200, 300)]
+init_test_keys_big = [str(i) for i in range(300, 400)]
+init_test_keys_device = [str(i) for i in range(400, 500)]
+init_test_keys_device_big = [str(i) for i in range(500, 600)]
+
+shape = (2, 3)
+big_shape = (1200, 1200)  # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
+
+kv = mx.kv.create(args.name)
+my_rank = kv.rank
+my_num_workers = kv.num_workers
+
+has_gpu = mx.context.num_gpus() > 0
+
+def get_current_context(device=False):
+    if has_gpu and device:
+        return mx.gpu(kv.local_rank)
+    else:
+        return mx.current_context()
+
+def test_pushpull():
+    def check_default_keys(nrepeat=3):
+        # init kv dns keys
+        kv.broadcast('3', mx.nd.ones(shape, ctx=get_current_context(device=True)), mx.nd.ones(shape, ctx=get_current_context(device=True)))
+        kv.broadcast('99', mx.nd.ones(big_shape, ctx=get_current_context(device=True)), mx.nd.ones(big_shape, ctx=get_current_context(device=True)))
+        for i in range(nrepeat):
+            scale = my_rank + 1
+            # sum of (rank + 1) over all workers: 1 + 2 + ... + num_workers
+            num = (my_num_workers + 1) * my_num_workers / 2
+
+            arr = mx.nd.ones(shape, ctx=get_current_context(device=True)) * scale
+            # inplace
+            kv.pushpull('3', arr)
+            check_diff_to_scalar(arr, num)
+
+            big_arr = mx.nd.ones(big_shape, ctx=get_current_context(device=True)) * scale
+            # inplace
+            kv.pushpull('99', big_arr)
+            check_diff_to_scalar(big_arr, num)
+
+    check_default_keys(nrepeat=3)
+    print('worker ' + str(my_rank) + ' is done')
+
+def test_broadcast():
+    def check_broadcast(kv, cur_keys, cur_shape, device=False):
+        print("check_broadcast: {}, {}, {}, {}".format(kv, cur_keys, cur_shape, device))
+        ctx = get_current_context(device=device)
+        val = [mx.nd.zeros(cur_shape, ctx) for i in cur_keys]
+        for i in range(len(cur_keys)):
+            expected = i
+            tmpNDarray = [mx.nd.ones(cur_shape, ctx) * i]
+            kv.broadcast(cur_keys[i], tmpNDarray, out=val[i])
+            check_diff_to_scalar(val[i], expected, my_rank)
+        print("check_broadcast passed: ", val)
+    # check_broadcast(kv, init_test_keys, shape)  # BytePS doesn't support pure CPU training
+    # check_broadcast(kv, init_test_keys_big, big_shape)  # BytePS doesn't support pure CPU training
+    check_broadcast(kv, init_test_keys_device, shape, device=True)
+    check_broadcast(kv, init_test_keys_device_big, big_shape, device=True)
+    print('worker ' + str(my_rank) + ' is initialized')
+
+def test_type():
+    assert kv.type == args.name
+
+if __name__ == "__main__":
+    print("Type Test Begin")
+    test_type()
+    print("Type Test Passed")
+    print("Broadcast Test Begin")
+    test_broadcast()
+    print("Broadcast Test Passed")
+    print("PushPull Test Begin")
+    test_pushpull()
+    print("PushPull Test Passed")
diff --git a/tools/byteps_launcher.py b/tools/byteps_launcher.py
new file mode 100644
index 000000000000..b65ffe2c89ce
--- /dev/null
+++ b/tools/byteps_launcher.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python
+
+# BytePS Copyright 2019 Bytedance Inc.
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Launch a distributed job for BytePS
+Combining the byteps/launcher/dist_launcher.py and byteps/launcher/launch.py of
+https://github.com/bytedance/byteps.git @ 2152d88
+"""
+import argparse
+import os
+import sys
+import signal
+import logging
+import subprocess
+from multiprocessing import Pool, Process
+from threading import Thread
+
+
+def preprocess_envs(args_envs):
+    envs_map = {}
+    for item in args_envs:
+        i = item.find(":")
+        if i != -1:
+            key = item[:i]
+            val = item[i+1:]
+            envs_map[key] = val
+    return envs_map
+
+
+def get_env(envs_map):
+    envs = []
+    # get system envs
+    keys = ['OMP_NUM_THREADS', 'KMP_AFFINITY']
+    for k in keys:
+        v = os.getenv(k)
+        if v is not None:
+            envs.append('export ' + k + '=' + v + ';')
+    # get the envs passed in via --env (args_envs)
+    for k, v in envs_map.items():
+        envs.append('export ' + str(k) + '=' + str(v) + ';')
+    return (' '.join(envs))
+
+
+def get_hosts_from_file(filename):
+    with open(filename) as f:
+        tmp = f.readlines()
+    assert len(tmp) > 0
+    hosts = []
+    for h in tmp:
+        if len(h.strip()) > 0:
+            # parse addresses of the form ip:port
+            h = h.strip()
+            i = h.find(":")
+            p = "22"
+            if i != -1:
+                p = h[i+1:]
+                h = h[:i]
+            # hosts now contains the pair ip, port
+            hosts.append((h, p))
+    return hosts
+
+
+def start_ssh(prog, node, port, username, fname):
+    def run(prog):
+        subprocess.check_call(prog, shell=True)
+
+    dirname = 'sshlog'
+    if not os.path.exists(dirname):
+        os.mkdir(dirname)
+
+    pname = dirname + '/' + fname
+    if username is not None:
+        prog = 'ssh -o StrictHostKeyChecking=no ' + ' -l ' + username \
+               + ' ' + node + ' -p ' + port + ' \'' + prog + '\'' \
+               + ' > ' + pname + '.stdout' + ' 2>' + pname + '.stderr&'
+    else:
+        prog = 'ssh -o StrictHostKeyChecking=no ' + node + ' -p ' + port + ' \'' + prog + '\'' \
+               + ' > ' + pname + '.stdout' + ' 2>' + pname + '.stderr&'
+
+    thread = Thread(target=run, args=(prog,))
+    thread.setDaemon(True)
+    thread.start()
+    return thread
+
+
+def submit(args):
+    if args.num_servers is None:
+        args.num_servers = args.num_workers
+    if args.server_hostfile is not None:
+        server_hosts = get_hosts_from_file(args.server_hostfile)
+        worker_hosts = get_hosts_from_file(args.hostfile)
+        args.num_workers = len(worker_hosts)
+        args.num_servers = len(server_hosts)
+    elif args.hostfile is not None:
+        assert (args.num_servers is not None and args.num_workers is not None), \
+            "For the BytePS backend, you must specify num_servers and num_workers"
+        all_hosts = get_hosts_from_file(args.hostfile)
+        assert(len(all_hosts) == args.num_workers + args.num_servers), \
+            "The sum of the number of workers and servers must be equal to \
+            the number of hosts in the hostfile"
+        server_hosts = all_hosts[:args.num_servers]
+        worker_hosts = all_hosts[args.num_servers:]
+    else:
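+        # No hostfile given: fall back to single-machine mode and launch the
+        # scheduler, all servers and all workers on localhost via ssh.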
+        print("Warning: no hostfile was specified, {} servers and {} workers will be launched on localhost".format(
+            args.num_servers, args.num_workers))
+        server_hosts = []
+        worker_hosts = []
+        for i in range(args.num_servers):
+            server_hosts.append(('localhost', '22'))
+        for i in range(args.num_workers):
+            worker_hosts.append(('localhost', '22'))
+
+    num_server = args.num_servers
+    num_worker = args.num_workers
+    assert num_server >= 1, "There must be at least one server."
+    assert num_worker >= 1, "There must be at least one worker."
+
+    print('Launch %d workers and %d servers' % (num_worker, num_server))
+
+    # common env
+    pass_envs = preprocess_envs(args.env)
+    pass_envs['DMLC_NUM_WORKER'] = str(num_worker)
+    pass_envs['DMLC_NUM_SERVER'] = str(num_server)
+    pass_envs['DMLC_PS_ROOT_URI'] = '127.0.0.1'
+    pass_envs['DMLC_PS_ROOT_PORT'] = str(22)
+
+    username = None
+    threads = []
+    for (node, port) in [('127.0.0.1', str(22))]:
+        name = 'scheduler'
+        pass_envs['DMLC_ROLE'] = name
+        print('Launching Scheduler...')
+        prog = get_env(pass_envs) + (" python3 -c " +
+                                     "\"" + "import byteps.server" + "\"")
+        threads.append(start_ssh(prog, node, port, username, name))
+
+    for i, (node, port) in enumerate(worker_hosts):
+        name = 'worker'
+        pass_envs['DMLC_ROLE'] = name
+        pass_envs['DMLC_WORKER_ID'] = str(i)
+        print('Launching Worker{} ...'.format(i))
+        local_size = max(len(os.getenv("NVIDIA_VISIBLE_DEVICES", "1").split(",")),
+                         len(pass_envs.get("NVIDIA_VISIBLE_DEVICES", "1").split(",")))
+
+        for local_rank in range(local_size):
+            pass_envs["BYTEPS_LOCAL_RANK"] = str(local_rank)
+            pass_envs["BYTEPS_LOCAL_SIZE"] = str(local_size)
+            command = args.command
+            if int(os.getenv("BYTEPS_ENABLE_GDB", 0)) or pass_envs.get("BYTEPS_ENABLE_GDB", 0) == "1":
+                if command.find("python3") != 0:
+                    command = "python3 " + command
+                command = ["gdb -ex 'run' -ex 'bt' -batch --args "] + command
+            prog = get_env(pass_envs) + (' '.join(command))
+
+            if pass_envs.get("BYTEPS_TRACE_ON", 0) == "1":
+                print("\n!!!Enable profiling for WORKER_ID: %s and local_rank: %d!!!" % (
+                    pass_envs["DMLC_WORKER_ID"], local_rank))
+                print("BYTEPS_TRACE_START_STEP: %s\tBYTEPS_TRACE_END_STEP: %s\t BYTEPS_TRACE_DIR: %s" % (
+                    pass_envs["BYTEPS_TRACE_START_STEP"], pass_envs["BYTEPS_TRACE_END_STEP"], pass_envs["BYTEPS_TRACE_DIR"]))
+                print("Command: %s\n" % command)
+                sys.stdout.flush()
+                trace_path = os.path.join(
+                    pass_envs["BYTEPS_TRACE_DIR"], str(local_rank))
+                if not os.path.exists(trace_path):
+                    os.makedirs(trace_path)
+            threads.append(
+                start_ssh(prog, node, port, username, name + str(i)))
+
+    for i, (node, port) in enumerate(server_hosts):
+        name = 'server'
+        pass_envs['DMLC_ROLE'] = name
+        print('Launching Server{} ...'.format(i))
+        prog = get_env(pass_envs) + (" python3 -c " +
+                                     "\"" + "import byteps.server" + "\"")
+        threads.append(start_ssh(prog, node, port, username, name + str(i)))
+
+    for t in threads:
+        t.join()
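+
+# NOTE: submit() expects an argparse.Namespace shaped like the one built by
+# tools/launch.py: num_workers, num_servers, hostfile, server_hostfile,
+# env (a list of "KEY:VALUE" strings) and command (the program and its arguments).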
diff --git a/tools/launch.py b/tools/launch.py
index 7000e061fd4b..c2ffcd96e6d4 100755
--- a/tools/launch.py
+++ b/tools/launch.py
@@ -63,7 +63,13 @@ def main():
                              in default it is equal to NUM_WORKERS')
     parser.add_argument('-H', '--hostfile', type=str,
                         help = 'the hostfile of slave machines which will run \
-                        the job. Required for ssh and mpi launcher')
+                        the job. Required for ssh and mpi launcher. \
+                        When -SH is set, the file provided by -H will \
+                        be used to recognize worker machines only. Otherwise, \
+                        -H is used for both server and worker machines.')
+    parser.add_argument('-SH', '--server-hostfile', type=str,
+                        help = 'the hostfile of server machines which will run \
+                        the job. Required for byteps multi-machine launching.')
     parser.add_argument('--sync-dst-dir', type=str,
                         help = 'if specificed, it will sync the current \
                         directory into slave machines\'s SYNC_DST_DIR if ssh \
@@ -71,6 +77,9 @@ def main():
     parser.add_argument('--launcher', type=str, default='ssh',
                         choices = ['local', 'ssh', 'mpi', 'sge', 'yarn'],
                         help = 'the launcher to use')
+    bps_group = parser.add_argument_group('byteps-backend')
+    bps_group.add_argument('--byteps', action='store_true',
+                           help = 'whether to use the byteps launcher')
     parser.add_argument('--env-server', action='append', default=[],
                         help = 'Given a pair of environment_variable:value, sets this value of \
                         environment variable for the server processes. This overrides values of \
@@ -92,6 +101,12 @@ def main():
                         help = 'command for launching the program')
     args, unknown = parser.parse_known_args()
     args.command += unknown
+
+    if args.byteps:
+        import byteps_launcher as bpsl
+        bpsl.submit(args)
+        return
+
     if args.num_servers is None:
         args.num_servers = args.num_workers
     if args.p3: