From 515386ef3cacb44a2bcfab9d66eaee6143d94e95 Mon Sep 17 00:00:00 2001 From: Roy Date: Fri, 29 Mar 2024 06:01:55 +0800 Subject: [PATCH] [Core] Support multi-node inference(eager and cuda graph) (#3686) --- tests/distributed/test_comm_ops.py | 6 +++--- tests/distributed/test_custom_all_reduce.py | 4 ++-- vllm/executor/ray_gpu_executor.py | 2 -- vllm/model_executor/parallel_utils/pynccl.py | 18 ++++++++---------- .../parallel_utils/pynccl_utils.py | 4 +++- vllm/test_utils.py | 6 +++++- vllm/worker/worker.py | 7 ++++--- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 0395f7200fd77..13916fc8c147b 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) num_elements = 8 all_tensors = [ @@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) @@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(1, tensor_parallel_size, rank, rank, distributed_init_port) test_dict = { "a": torch.arange(8, dtype=torch.float32, device="cuda"), diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528c..0bd3bf8837450 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(1, world_size, rank, rank, distributed_init_port) custom_ar.init_custom_ar() @@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(1, world_size, rank, rank, distributed_init_port) sz = 1024 diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4ac72bb0de34c..8f80c20738bba 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,8 +188,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - # FIXME(woosuk): We are not properly initializing pynccl when - # we have multiple nodes. self._run_workers("init_device") self._run_workers( "load_model", diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 0eb75e02d62cf..968dd7e17d021 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -202,6 +202,7 @@ def __init__( init_method=None, timeout=datetime.timedelta(seconds=10), world_size: int = -1, + local_rank: int = -1, rank: int = -1, store=None, group_name: str = "", @@ -219,25 +220,22 @@ def __init__( store=store, group_name=group_name, pg_options=pg_options) - self.world_size = dist.get_world_size() - self.rank = dist.get_rank() - torch.cuda.set_device(self.rank) - if self.rank == 0: + torch.cuda.set_device(local_rank) + if rank == 0: self.unique_id = ncclGetUniqueId() else: self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( - self.rank) + tensor = torch.ByteTensor(list( + self.unique_id.internal)).cuda(local_rank) dist.broadcast(tensor, src=0) byte_list = tensor.cpu().tolist() - self.unique_id = NcclUniqueId() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank) + result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size, + self.unique_id, rank) assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}") + self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}") def all_reduce(self, tensor: torch.Tensor, diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py index a12d620d7a24c..5b8ee4c4de598 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -36,11 +36,13 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(world_size: int, rank: int, init_method: str) -> None: +def init_process_group(world_size: int, local_rank: int, rank: int, + init_method: str) -> None: assert not is_initialized() global comm comm = NCCLCommunicator(init_method=init_method, world_size=world_size, + local_rank=local_rank, rank=rank) diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 5b2eeafad197e..735cc0037ba5f 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -8,6 +8,7 @@ def init_test_distributed_environment( pipeline_parallel_size: int, tensor_parallel_size: int, + local_rank: int, rank: int, distributed_init_port: str, ) -> None: @@ -16,7 +17,10 @@ def init_test_distributed_environment( worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - parallel_config, rank, distributed_init_method=distributed_init_method) + parallel_config, + local_rank, + rank, + distributed_init_method=distributed_init_method) def multi_process_tensor_parallel( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6459c0cda669a..4ffe780400101 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -97,8 +97,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + init_distributed_environment(self.parallel_config, self.local_rank, + self.rank, self.distributed_init_method) # Set random seed. set_random_seed(self.model_config.seed) @@ -249,6 +249,7 @@ def get_cache_block_size_bytes(self, block_size: int, def init_distributed_environment( parallel_config: ParallelConfig, + local_rank: int, rank: int, distributed_init_method: Optional[str] = None, ) -> None: @@ -282,9 +283,9 @@ def init_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - # TODO(woosuk): Support multi-node connection. pynccl_utils.init_process_group( world_size=parallel_config.world_size, + local_rank=local_rank, rank=rank, init_method=distributed_init_method, )