diff --git a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py index 4487d11ca1e..e390d52fa1a 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py @@ -101,7 +101,7 @@ def __init__( if backend == "pytorch": cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node num_nodes = ray_ctx.num_ray_nodes * workers_per_node - RemoteRunner = ray.remote(num_cpus=1)(TorchRunner) + RemoteRunner = ray.remote(num_cpus=cores_per_node)(TorchRunner) self.remote_workers = [ RemoteRunner.remote(**params) for i in range(num_nodes) ] @@ -110,9 +110,10 @@ def __init__( for i, worker in enumerate(self.remote_workers) ]) - ip = ray.services.get_node_ip_address() - port = utils.find_free_port() - address = "tcp://{ip}:{port}".format(ip=ip, port=port) + head_worker = self.remote_workers[0] + address = ray.get(head_worker.setup_address.remote()) + + logger.info(f"initializing pytorch process group on {address}") ray.get([ worker.setup_torch_distribute.remote(address, i, num_nodes) diff --git a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py index f4bf90e6dc4..3ef03bf5887 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py @@ -47,6 +47,7 @@ from zoo.orca.learn.pytorch.constants import SCHEDULER_STEP, NUM_STEPS from zoo.orca.learn.pytorch.training_operator import TrainingOperator from zoo.orca.learn.pytorch import utils +from zoo.orca.learn.pytorch.utils import find_free_port logger = logging.getLogger(__name__) @@ -124,6 +125,11 @@ def setup_horovod(self): self.setup_components_horovod() self.setup_operator() + def setup_address(self): + ip = ray.services.get_node_ip_address() + port = find_free_port() + return f"tcp://{ip}:{port}" + def setup_torch_distribute(self, url, world_rank, world_size): import torch.distributed as dist dist.init_process_group(