From f7db9eac52162c15be42d1af15662a56ac3b9342 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 9 Apr 2024 01:49:02 -0700
Subject: [PATCH] [Core] separate distributed_init from worker (#3904)

---
 .../parallel_utils/parallel_state.py | 63 ++++++++++++++++++-
 vllm/test_utils.py                   | 13 ++--
 vllm/worker/cpu_worker.py            | 28 +++------
 vllm/worker/worker.py                | 39 ++++--------
 4 files changed, 85 insertions(+), 58 deletions(-)

diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py
index bcda5ebf8548b..3bbfa1bd5443a 100644
--- a/vllm/model_executor/parallel_utils/parallel_state.py
+++ b/vllm/model_executor/parallel_utils/parallel_state.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+from typing import Optional
 
 import torch
 
@@ -14,14 +15,59 @@
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
+# When people blindly call `torch.distributed.all_reduce` etc.,
+# it will use this group. It is initialized with the `backend`
+# parameter of `init_distributed_environment` below.
+# Essentially, this is `torch.distributed.group.WORLD`.
+# We leave a line here to note that this is device-specific.
+# Note that this variable is not safe to use, because when users
+# call `init_distributed_environment` first, and then destroy
+# the process group themselves, this variable will keep a reference to the
+# destroyed process group, which is not useful.
+_DEVICE_WORLD_GROUP = None
+
+# During `init_distributed_environment`, we will also initialize a
+# group with `gloo` backend, to allow direct coordination between
+# processes through the CPU.
+_CPU_WORLD_GROUP = None
+
+# In summary, after calling `init_distributed_environment`, we will
+# always have two groups: one that is device-specific (and is the default)
+# and one for the CPU. All processes will be part of both groups.
+
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
 
+def init_distributed_environment(
+    world_size: int,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
+    backend: str = "nccl",
+):
+    if not torch.distributed.is_initialized():
+        assert distributed_init_method is not None, (
+            "distributed_init_method must be provided when initializing "
+            "distributed environment")
+        # this backend is used for WORLD
+        torch.distributed.init_process_group(
+            backend=backend,
+            init_method=distributed_init_method,
+            world_size=world_size,
+            rank=rank)
+        global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
+        _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
+        ranks = list(range(torch.distributed.get_world_size()))
+        _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
+                                                       backend="gloo")
+
+
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -48,6 +94,8 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
 
     if (world_size !=
             tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -69,7 +117,7 @@ def initialize_model_parallel(
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
 
@@ -80,7 +128,7 @@
            "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
 
@@ -89,14 +137,17 @@
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
     or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
     values if the model parallel groups are initialized.
     """
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size)
+                                  pipeline_model_parallel_size, backend)
         return
 
     assert (
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
+def get_cpu_world_group():
+    """Get the CPU world group."""
+    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
+    return _CPU_WORLD_GROUP
+
+
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
     assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
diff --git a/vllm/test_utils.py b/vllm/test_utils.py
index 17ea67b740a8c..6a3e77fd009f2 100644
--- a/vllm/test_utils.py
+++ b/vllm/test_utils.py
@@ -1,8 +1,8 @@
 import ray
 
-from vllm.config import ParallelConfig
+from vllm.model_executor.parallel_utils.parallel_state import (
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.utils import get_open_port
-from vllm.worker.worker import init_distributed_environment
 
 
 def init_test_distributed_environment(
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
     distributed_init_port: str,
     local_rank: int = -1,
 ) -> None:
-    parallel_config = ParallelConfig(pipeline_parallel_size,
-                                     tensor_parallel_size,
-                                     worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        parallel_config,
-        rank,
+        world_size=pipeline_parallel_size * tensor_parallel_size,
+        rank=rank,
         distributed_init_method=distributed_init_method,
         local_rank=local_rank)
+    ensure_model_parallel_initialized(tensor_parallel_size,
+                                      pipeline_parallel_size)
 
 
 def multi_process_tensor_parallel(
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 262ed9abd36b7..e1daa64346a9c 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -13,7 +13,7 @@
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    ensure_model_parallel_initialized)
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.model_runner import ModelRunner
@@ -251,26 +251,12 @@ def init_distributed_environment(self) -> None:
         parallel_config = self.parallel_config
         rank = self.rank
         distributed_init_method = self.distributed_init_method
-
-        if torch.distributed.is_initialized():
-            torch_world_size = torch.distributed.get_world_size()
-            if torch_world_size != parallel_config.world_size:
-                raise RuntimeError(
-                    "torch.distributed is already initialized but the torch "
-                    "world size does not match parallel_config.world_size "
-                    f"({torch_world_size} vs. {parallel_config.world_size}).")
-        elif not distributed_init_method:
-            raise ValueError(
-                "distributed_init_method must be set if torch.distributed "
-                "is not already initialized")
-        else:
-            backend = "gloo"
-            torch.distributed.init_process_group(
-                backend=backend,
-                world_size=parallel_config.world_size,
-                rank=rank,
-                init_method=distributed_init_method,
-            )
+        init_distributed_environment(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            backend="gloo",
+        )
 
         # A small all_reduce for warmup.
         torch.distributed.all_reduce(torch.zeros(1).cpu())
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 48facb57de190..bf0c6073ea9a9 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -15,7 +15,7 @@
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
 from vllm.model_executor.parallel_utils.parallel_state import (
-    ensure_model_parallel_initialized)
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
@@ -97,9 +97,9 @@ def init_device(self) -> None:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_distributed_environment(self.parallel_config, self.rank,
-                                     self.distributed_init_method,
-                                     self.local_rank)
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
@@ -248,31 +248,15 @@ def get_cache_block_size_bytes(self, block_size: int,
                                                 self.parallel_config)
 
 
-def init_distributed_environment(
+def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
-    if torch.distributed.is_initialized():
-        torch_world_size = torch.distributed.get_world_size()
-        if torch_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "torch.distributed is already initialized but the torch world "
-                "size does not match parallel_config.world_size "
-                f"({torch_world_size} vs. {parallel_config.world_size}).")
{parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - torch.distributed.init_process_group( - backend="nccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) if pynccl_utils.is_initialized(): pynccl_world_size = pynccl_utils.get_world_size() @@ -291,10 +275,6 @@ def init_distributed_environment( init_method=distributed_init_method, ) - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) @@ -302,6 +282,11 @@ def init_distributed_environment( if not parallel_config.disable_custom_all_reduce: init_custom_ar() + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype.