[Core] separate distributed_init from worker (vllm-project#3904)
youkaichao authored and joerunde committed Apr 11, 2024
1 parent fff0da9 commit cee255f
Showing 4 changed files with 85 additions and 58 deletions.
63 changes: 60 additions & 3 deletions vllm/model_executor/parallel_utils/parallel_state.py
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from typing import Optional

import torch

@@ -14,14 +15,59 @@
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None

# When callers invoke `torch.distributed.all_reduce` etc. without an explicit
# group, this group is used. It is initialized with the `backend`
# parameter of `init_distributed_environment` below and is essentially
# `torch.distributed.group.WORLD`. We keep a named reference here to note
# that it is device-specific. Note that this variable is not safe to rely
# on: if users call `init_distributed_environment` and later destroy the
# process group themselves, this variable will keep a reference to the
# destroyed group.
_DEVICE_WORLD_GROUP = None

# During `init_distributed_environment`, we will also initialize a
# group with the `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None

# In summary, after calling `init_distributed_environment`, we will
# always have two groups: a device-specific group (the default one) and
# a CPU group. All processes are part of both groups.

# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None


def init_distributed_environment(
world_size: int,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "nccl",
):
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment")
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank)
global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
_DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
@@ -48,6 +94,8 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()

if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -69,7 +117,7 @@ def initialize_model_parallel(
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group

@@ -80,7 +128,7 @@ def initialize_model_parallel(
"pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
@@ -89,14 +137,17 @@
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size)
pipeline_model_parallel_size, backend)
return

assert (
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)


def get_cpu_world_group():
"""Get the CPU world group."""
assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
return _CPU_WORLD_GROUP


def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
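For reference, a minimal single-process sketch of the new module-level API added in this file. It is illustrative only: it assumes vLLM at this commit is installed, uses a hypothetical free local port (29500), and picks the `gloo` backend so it can run on a machine without a GPU; the import paths and keyword names mirror the hunks above.

```python
import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized, get_cpu_world_group,
    init_distributed_environment)

# Create the default (device) world group plus the extra gloo CPU group.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://localhost:29500",  # hypothetical free port
    backend="gloo",  # assumption: gloo so this sketch runs without a GPU
)

# Build the tensor/pipeline model parallel groups; the backend defaults to
# whatever backend the WORLD group was created with.
ensure_model_parallel_initialized(tensor_model_parallel_size=1,
                                  pipeline_model_parallel_size=1)

# Coordinate directly over CPU through the group introduced by this commit.
torch.distributed.all_reduce(torch.zeros(1), group=get_cpu_world_group())
```

The first two calls, run across multiple processes, match what `init_test_distributed_environment` in `vllm/test_utils.py` now does (next file).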
13 changes: 6 additions & 7 deletions vllm/test_utils.py
@@ -1,8 +1,8 @@
import ray

from vllm.config import ParallelConfig
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized, init_distributed_environment)
from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment


def init_test_distributed_environment(
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
distributed_init_port: str,
local_rank: int = -1,
) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
parallel_config,
rank,
world_size=pipeline_parallel_size * tensor_parallel_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank)
ensure_model_parallel_initialized(tensor_parallel_size,
pipeline_parallel_size)


def multi_process_tensor_parallel(
28 changes: 7 additions & 21 deletions vllm/worker/cpu_worker.py
@@ -13,7 +13,7 @@
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.model_runner import ModelRunner
@@ -251,26 +251,12 @@ def init_distributed_environment(self) -> None:
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method

if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
backend = "gloo"
torch.distributed.init_process_group(
backend=backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)

# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
39 changes: 12 additions & 27 deletions vllm/worker/worker.py
@@ -15,7 +15,7 @@
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
@@ -97,9 +97,9 @@ def init_device(self) -> None:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)

@@ -248,31 +248,15 @@ def get_cache_block_size_bytes(self, block_size: int,
self.parallel_config)


def init_distributed_environment(
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)

if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size()
@@ -291,17 +275,18 @@ def init_distributed_environment(
init_method=distributed_init_method,
)

# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

# Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce:
init_custom_ar()

# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if pynccl_utils.is_initialized():
pynccl_utils.all_reduce(torch.zeros(1).cuda())


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
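To make the convergence explicit, here is a hypothetical helper (not part of the commit) that condenses the two worker call sites shown above: `cpu_worker.py` passes `backend="gloo"`, while `worker.py` relies on the `"nccl"` default.

```python
from typing import Optional

from vllm.model_executor.parallel_utils.parallel_state import (
    init_distributed_environment)


def _init_from_worker(world_size: int,
                      rank: int,
                      distributed_init_method: Optional[str],
                      use_cpu: bool) -> None:
    """Illustration only: both workers now reach the same entry point."""
    # cpu_worker.py hard-codes "gloo"; worker.py keeps the "nccl" default.
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        backend="gloo" if use_cpu else "nccl",
    )
```

The duplicated `torch.distributed` bootstrap logic that used to live in each worker now sits behind this single call in `parallel_state.py`.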
