From 9628039443928d2c9562850d64ac6a702bb1b2a5 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Wed, 28 Oct 2020 23:30:21 +0800 Subject: [PATCH] Fix orca ray pytorch example (#3007) * fix horovod pytorch example * fix bug * fix process group * fix style * fix tests * fix test * fix tests * revert ray context change --- python/orca/src/bigdl/orca/ray/raycontext.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/orca/src/bigdl/orca/ray/raycontext.py b/python/orca/src/bigdl/orca/ray/raycontext.py index 82f04845265..a204b34a7c6 100755 --- a/python/orca/src/bigdl/orca/ray/raycontext.py +++ b/python/orca/src/bigdl/orca/ray/raycontext.py @@ -465,8 +465,9 @@ def init(self, driver_cores=0): from bigdl.util.common import init_executor_gateway init_executor_gateway(self.sc) print("JavaGatewayServer has been successfully launched on executors") - self._start_cluster() - self._address_info = self._start_driver(num_cores=driver_cores) + redis_address = self._start_cluster() + self._address_info = self._start_driver(num_cores=driver_cores, + redis_address=redis_address) print(self._address_info) kill_redundant_log_monitors(self._address_info["redis_address"]) @@ -494,14 +495,14 @@ def _start_cluster(self): self.ray_processesMonitor = ProcessMonitor(process_infos, self.sc, ray_rdd, self, verbose=self.verbose) - return self + return self.ray_processesMonitor.master.master_addr - def _start_restricted_worker(self, num_cores, node_ip_address): + def _start_restricted_worker(self, num_cores, node_ip_address, redis_address): extra_param = {"node-ip-address": node_ip_address} if self.extra_params is not None: extra_param.update(self.extra_params) command = RayServiceFuncGenerator._get_raylet_command( - redis_address=self.redis_address, + redis_address=redis_address, ray_exec="ray", password=self.redis_password, ray_node_cpu_cores=num_cores, @@ -513,13 +514,14 @@ def _start_restricted_worker(self, num_cores, node_ip_address): tag="raylet", 
fail_fast=True) ProcessMonitor.register_shutdown_hook(pgid=process_info.pgid) - def _start_driver(self, num_cores=0): + def _start_driver(self, num_cores, redis_address): print("Start to launch ray driver on local") import ray.services - node_ip = ray.services.get_node_ip_address(self.redis_address) + node_ip = ray.services.get_node_ip_address(redis_address) self._start_restricted_worker(num_cores=num_cores, - node_ip_address=node_ip) + node_ip_address=node_ip, + redis_address=redis_address) ray.shutdown() - return ray.init(address=self.redis_address, + return ray.init(address=redis_address, redis_password=self.ray_service.password, node_ip_address=node_ip)