diff --git a/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py b/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py index 4487d11ca1e..e390d52fa1a 100644 --- a/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py +++ b/pyzoo/zoo/orca/learn/pytorch/pytorch_ray_estimator.py @@ -101,7 +101,7 @@ def __init__( if backend == "pytorch": cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node num_nodes = ray_ctx.num_ray_nodes * workers_per_node - RemoteRunner = ray.remote(num_cpus=1)(TorchRunner) + RemoteRunner = ray.remote(num_cpus=cores_per_node)(TorchRunner) self.remote_workers = [ RemoteRunner.remote(**params) for i in range(num_nodes) ] @@ -110,9 +110,10 @@ def __init__( for i, worker in enumerate(self.remote_workers) ]) - ip = ray.services.get_node_ip_address() - port = utils.find_free_port() - address = "tcp://{ip}:{port}".format(ip=ip, port=port) + head_worker = self.remote_workers[0] + address = ray.get(head_worker.setup_address.remote()) + + logger.info(f"initializing pytorch process group on {address}") ray.get([ worker.setup_torch_distribute.remote(address, i, num_nodes) diff --git a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py index f4bf90e6dc4..3ef03bf5887 100644 --- a/pyzoo/zoo/orca/learn/pytorch/torch_runner.py +++ b/pyzoo/zoo/orca/learn/pytorch/torch_runner.py @@ -47,6 +47,7 @@ from zoo.orca.learn.pytorch.constants import SCHEDULER_STEP, NUM_STEPS from zoo.orca.learn.pytorch.training_operator import TrainingOperator from zoo.orca.learn.pytorch import utils +from zoo.orca.learn.pytorch.utils import find_free_port logger = logging.getLogger(__name__) @@ -124,6 +125,11 @@ def setup_horovod(self): self.setup_components_horovod() self.setup_operator() + def setup_address(self): + ip = ray.services.get_node_ip_address() + port = find_free_port() + return f"tcp://{ip}:{port}" + def setup_torch_distribute(self, url, world_rank, world_size): import torch.distributed as dist dist.init_process_group( diff --git a/pyzoo/zoo/ray/raycontext.py b/pyzoo/zoo/ray/raycontext.py index 82f04845265..a204b34a7c6 100755 --- a/pyzoo/zoo/ray/raycontext.py +++ b/pyzoo/zoo/ray/raycontext.py @@ -465,8 +465,9 @@ def init(self, driver_cores=0): from bigdl.util.common import init_executor_gateway init_executor_gateway(self.sc) print("JavaGatewayServer has been successfully launched on executors") - self._start_cluster() - self._address_info = self._start_driver(num_cores=driver_cores) + redis_address = self._start_cluster() + self._address_info = self._start_driver(num_cores=driver_cores, + redis_address=redis_address) print(self._address_info) kill_redundant_log_monitors(self._address_info["redis_address"]) @@ -494,14 +495,14 @@ def _start_cluster(self): self.ray_processesMonitor = ProcessMonitor(process_infos, self.sc, ray_rdd, self, verbose=self.verbose) - return self + return self.ray_processesMonitor.master.master_addr - def _start_restricted_worker(self, num_cores, node_ip_address): + def _start_restricted_worker(self, num_cores, node_ip_address, redis_address): extra_param = {"node-ip-address": node_ip_address} if self.extra_params is not None: extra_param.update(self.extra_params) command = RayServiceFuncGenerator._get_raylet_command( - redis_address=self.redis_address, + redis_address=redis_address, ray_exec="ray", password=self.redis_password, ray_node_cpu_cores=num_cores, @@ -513,13 +514,14 @@ def _start_restricted_worker(self, num_cores, node_ip_address): tag="raylet", fail_fast=True) ProcessMonitor.register_shutdown_hook(pgid=process_info.pgid) - def _start_driver(self, num_cores=0): + def _start_driver(self, num_cores, redis_address): print("Start to launch ray driver on local") import ray.services - node_ip = ray.services.get_node_ip_address(self.redis_address) + node_ip = ray.services.get_node_ip_address(redis_address) self._start_restricted_worker(num_cores=num_cores, - node_ip_address=node_ip) + node_ip_address=node_ip, + redis_address=redis_address) ray.shutdown() - return ray.init(address=self.redis_address, + return ray.init(address=redis_address, redis_password=self.ray_service.password, node_ip_address=node_ip)