From 092a64839f008346612c962a22b025c24c588c03 Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Wed, 28 Oct 2020 23:30:21 +0800
Subject: [PATCH] Fix orca ray pytorch example (#3007)

* fix horovod pytorch example

* fix bug

* fix process group

* fix style

* fix tests

* fix test

* fix tests

* revert ray context change
---
 .../bigdl/orca/learn/pytorch/pytorch_ray_estimator.py    | 9 +++++----
 python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py | 6 ++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
index 4487d11ca1e..e390d52fa1a 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
@@ -101,7 +101,7 @@ def __init__(
         if backend == "pytorch":
             cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
             num_nodes = ray_ctx.num_ray_nodes * workers_per_node
-            RemoteRunner = ray.remote(num_cpus=1)(TorchRunner)
+            RemoteRunner = ray.remote(num_cpus=cores_per_node)(TorchRunner)
             self.remote_workers = [
                 RemoteRunner.remote(**params) for i in range(num_nodes)
             ]
@@ -110,9 +110,10 @@ def __init__(
                 for i, worker in enumerate(self.remote_workers)
             ])

-            ip = ray.services.get_node_ip_address()
-            port = utils.find_free_port()
-            address = "tcp://{ip}:{port}".format(ip=ip, port=port)
+            head_worker = self.remote_workers[0]
+            address = ray.get(head_worker.setup_address.remote())
+
+            logger.info(f"initializing pytorch process group on {address}")

             ray.get([
                 worker.setup_torch_distribute.remote(address, i, num_nodes)
diff --git a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
index f4bf90e6dc4..3ef03bf5887 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py
@@ -47,6 +47,7 @@
 from zoo.orca.learn.pytorch.constants import SCHEDULER_STEP, NUM_STEPS
 from zoo.orca.learn.pytorch.training_operator import TrainingOperator
 from zoo.orca.learn.pytorch import utils
+from zoo.orca.learn.pytorch.utils import find_free_port

 logger = logging.getLogger(__name__)

@@ -124,6 +125,11 @@ def setup_horovod(self):
         self.setup_components_horovod()
         self.setup_operator()

+    def setup_address(self):
+        ip = ray.services.get_node_ip_address()
+        port = find_free_port()
+        return f"tcp://{ip}:{port}"
+
     def setup_torch_distribute(self, url, world_rank, world_size):
         import torch.distributed as dist
         dist.init_process_group(
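
For context, the heart of this patch is where the rendezvous address gets created: instead of the driver picking an IP and free port with ray.services.get_node_ip_address() and utils.find_free_port(), the head worker now builds the tcp:// address on its own node via the new setup_address() method, so the chosen port is actually free on the machine that hosts rank 0 when torch.distributed binds it. The standalone sketch below shows that pattern outside of Ray; it is a minimal, single-machine illustration, and the helper names in it (find_free_port, setup_address, run_worker) are assumptions for the demo rather than Orca's real API.

    # Illustrative sketch (not part of the patch): the rendezvous pattern this
    # change adopts. The rank-0 process picks a free port on its own node and
    # publishes a tcp:// address; every rank then joins the PyTorch process
    # group through that address. Requires torch with the gloo backend; the
    # helper names here are hypothetical stand-ins, not BigDL Orca APIs.
    import socket

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def find_free_port():
        # Bind to port 0 so the OS hands back an unused port on this node.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


    def setup_address():
        # In the patch this runs on the head worker, so the chosen port is free
        # on the node that will actually host rank 0. For this single-machine
        # demo the loopback address is enough; a multi-node setup would publish
        # the node's routable IP instead (e.g. what Ray reports for the node).
        return f"tcp://127.0.0.1:{find_free_port()}"


    def run_worker(rank, world_size, address):
        # Every rank joins the same process group via the published address;
        # rank 0 binds the port, the others connect to it.
        dist.init_process_group(backend="gloo", init_method=address,
                                rank=rank, world_size=world_size)
        print(f"rank {rank} joined a group of size {dist.get_world_size()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        address = setup_address()
        mp.spawn(run_worker, args=(world_size, address), nprocs=world_size)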