diff --git a/python/orca/src/bigdl/orca/ray/util/raycontext.py b/python/orca/src/bigdl/orca/ray/util/raycontext.py index 890deb72259..7ba75842da0 100755 --- a/python/orca/src/bigdl/orca/ray/util/raycontext.py +++ b/python/orca/src/bigdl/orca/ray/util/raycontext.py @@ -92,7 +92,8 @@ def _prepare_env(self, cores=None): return modified_env def __init__(self, python_loc, redis_port, ray_node_cpu_cores, mkl_cores, - password, object_store_memory, waitting_time_sec=6, verbose=False, env=None): + password, object_store_memory, waitting_time_sec=6, verbose=False, env=None, + extra_params=None): """object_store_memory: integer in bytes""" self.env = env self.python_loc = python_loc @@ -103,6 +104,7 @@ def __init__(self, python_loc, redis_port, ray_node_cpu_cores, mkl_cores, self.ray_exec = self._get_ray_exec() self.object_store_memory = object_store_memory self.waiting_time_sec = waitting_time_sec + self.extra_params = extra_params self.verbose = verbose self.labels = """--resources='{"trainer": %s, "ps": %s }' """ % (1, 1) @@ -115,22 +117,25 @@ def _stop(iter): return _stop - def _gen_master_command(self): - command = "{} start --head " \ - "--include-webui --redis-port {} \ - --redis-password {} --num-cpus {} ". \ - format(self.ray_exec, self.redis_port, self.password, self.ray_node_cpu_cores) + def _enrich_command(self, command): if self.object_store_memory: command = command + "--object-store-memory {} ".format(str(self.object_store_memory)) + if self.extra_params: + for pair in self.extra_params.items(): + command = command + " --{} {} ".format(pair[0], pair[1]) return command + def _gen_master_command(self): + command = "{} start --head " \ + "--include-webui --redis-port {} " \ + "--redis-password {} --num-cpus {} ". \ + format(self.ray_exec, self.redis_port, self.password, self.ray_node_cpu_cores) + return self._enrich_command(command) + def _get_raylet_command(self, redis_address): command = "{} start --redis-address {} --redis-password {} --num-cpus {} {} ".format( self.ray_exec, redis_address, self.password, self.ray_node_cpu_cores, self.labels) - - if self.object_store_memory: - command = command + "--object-store-memory {} ".format(str(self.object_store_memory)) - return command + return self._enrich_command(command) def _start_ray_node(self, command, tag, wait_before=5, wait_after=5): modified_env = self._prepare_env(self.mkl_cores) @@ -180,18 +185,24 @@ def _start_ray_services(iter): class RayContext(object): def __init__(self, sc, redis_port=None, password="123456", object_store_memory=None, - verbose=False, env=None, local_ray_node_num=2, waitting_time_sec=8): + verbose=False, env=None, local_ray_node_num=2, waiting_time_sec=8, + extra_params=None): """ - The RayContext would init a ray cluster on top of the configuration of the SparkContext. + The RayContext would init a ray cluster on top of the configuration of SparkContext. For spark cluster mode: The number of raylets is equal to number of executors. For Spark local mode: The number of raylets is controlled by local_ray_node_num. - CPU cores for each raylet equals to spark_cores/local_ray_node_num. + CPU cores for each is raylet equals to spark_cores/local_ray_node_num. :param sc: :param redis_port: redis port for the "head" node. - The value would be randomly picked if not specified - :param local_ray_node_num number of raylets to be created. + The value would be randomly picked if not specified. + :param password: [optional] password for the redis. :param object_store_memory: Memory size for the object_store. + :param verbose: True for more logs. :param env: The environment variable dict for running Ray. + :param local_ray_node_num number of raylets to be created. + :param waiting_time_sec: Waiting time for the raylets before connecting to redis. + :param extra_params: key value dictionary for extra options to launch Ray. + i.e extra_params={"temp-dir": "/tmp/ray2/"} """ self.sc = sc self.stopped = False @@ -212,7 +223,8 @@ def __init__(self, sc, redis_port=None, password="123456", object_store_memory=N object_store_memory=self._enrich_object_sotre_memory(sc, object_store_memory), verbose=verbose, env=env, - waitting_time_sec=waitting_time_sec) + waitting_time_sec=waiting_time_sec, + extra_params=extra_params) self._gather_cluster_ips() from bigdl.util.common import init_executor_gateway print("Start to launch the JVM guarding process") diff --git a/python/orca/src/bigdl/orca/ray/util/spark.py b/python/orca/src/bigdl/orca/ray/util/spark.py index b0c58c84aef..802a8bad8df 100644 --- a/python/orca/src/bigdl/orca/ray/util/spark.py +++ b/python/orca/src/bigdl/orca/ray/util/spark.py @@ -98,6 +98,18 @@ def _detect_python_location(self): "Cannot detect current python location. Please set it manually by python_location") return process_info.out + def _gather_essential_jars(self): + from bigdl.util.engine import get_bigdl_classpath + from zoo.util.engine import get_analytics_zoo_classpath + bigdl_classpath = get_bigdl_classpath() + zoo_classpath = get_analytics_zoo_classpath() + assert bigdl_classpath, "Cannot find bigdl classpath" + assert zoo_classpath, "Cannot find Analytics-Zoo classpath" + if bigdl_classpath == zoo_classpath: + return [zoo_classpath] + else: + return [zoo_classpath, bigdl_classpath] + def init_spark_on_local(self, cores, conf=None, python_location=None): print("Start to getOrCreate SparkContext") os.environ['PYSPARK_PYTHON'] =\ @@ -124,29 +136,24 @@ def init_spark_on_yarn(self, penv_archive=None, hadoop_user_name="root", spark_yarn_archive=None, + spark_conf=None, jars=None): os.environ["HADOOP_CONF_DIR"] = hadoop_conf os.environ['HADOOP_USER_NAME'] = hadoop_user_name os.environ['PYSPARK_PYTHON'] = "python_env/bin/python" def _yarn_opt(jars): - from zoo.util.engine import get_analytics_zoo_classpath command = " --archives {}#python_env --num-executors {} " \ " --executor-cores {} --executor-memory {}".\ format(penv_archive, num_executor, executor_cores, executor_memory) - path_to_zoo_jar = get_analytics_zoo_classpath() + jars_list = self._gather_essential_jars() + if jars: + jars_list.append(jars) if extra_python_lib: command = command + " --py-files {} ".format(extra_python_lib) - if jars: - command = command + " --jars {},{} ".format(jars, path_to_zoo_jar) - elif path_to_zoo_jar: - command = command + " --jars {} ".format(path_to_zoo_jar) - - if path_to_zoo_jar: - command = command + " --conf spark.driver.extraClassPath={} ".\ - format(get_analytics_zoo_classpath()) + command = command + " --jars {}".format(",".join(jars_list)) return command def _submit_opt(): @@ -158,7 +165,7 @@ def _submit_opt(): conf["spark.executor.memoryOverhead"] = extra_executor_memory_for_ray if spark_yarn_archive: conf.insert("spark.yarn.archive", spark_yarn_archive) - return " --master yarn " + _yarn_opt(jars) + 'pyspark-shell', conf + return " --master yarn --deploy-mode client" + _yarn_opt(jars) + ' pyspark-shell ', conf pack_env = False assert penv_archive or conda_name, \ @@ -169,6 +176,9 @@ def _submit_opt(): pack_env = True submit_args, conf = _submit_opt() + if spark_conf: + for item in spark_conf.items(): + conf[str(item[0])] = str(item[1]) sc = self._create_sc(submit_args, conf) finally: if conda_name and penv_archive and pack_env: diff --git a/python/orca/src/test/bigdl/orca/ray/integration/ray_on_yarn.py b/python/orca/src/test/bigdl/orca/ray/integration/ray_on_yarn.py index cc9207bf087..8c60c965381 100644 --- a/python/orca/src/test/bigdl/orca/ray/integration/ray_on_yarn.py +++ b/python/orca/src/test/bigdl/orca/ray/integration/ray_on_yarn.py @@ -30,10 +30,12 @@ executor_memory="10g", driver_memory="2g", driver_cores=4, - extra_executor_memory_for_ray="30g") + extra_executor_memory_for_ray="30g", + spark_conf={"hello": "world"}) ray_ctx = RayContext(sc=sc, object_store_memory="25g", + extra_params={"temp-dir": "/tmp/hello/"}, env={"http_proxy": "http://child-prc.intel.com:913", "http_proxys": "http://child-prc.intel.com:913"}) ray_ctx.init() diff --git a/python/orca/src/test/bigdl/orca/ray/test_ray_on_local.py b/python/orca/src/test/bigdl/orca/ray/test_ray_on_local.py index 689ac0ba8c2..e381c028080 100644 --- a/python/orca/src/test/bigdl/orca/ray/test_ray_on_local.py +++ b/python/orca/src/test/bigdl/orca/ray/test_ray_on_local.py @@ -47,8 +47,6 @@ def test_local(self): for process_info in ray_ctx.ray_processesMonitor.process_infos: for pid in process_info.pids: assert not psutil.pid_exists(pid) - sc.stop() - if __name__ == "__main__": pytest.main([__file__])