diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 3e940549862ea..705e81d15ad65 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -120,6 +120,7 @@ steps:
   - tests/spec_decode/e2e/test_integration_dist_tp4
   - tests/compile
   commands:
+  - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@@ -431,7 +432,6 @@ steps:
   - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
 
 - label: Multi-step Tests (4 GPUs) # 36min
   working_dir: "/vllm-workspace/tests"
diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py
index a51a9909f6f41..3c7facc12c59a 100644
--- a/tests/distributed/test_utils.py
+++ b/tests/distributed/test_utils.py
@@ -1,9 +1,15 @@
+import pytest
 import ray
+import torch
+import torch.distributed as dist
 
 import vllm.envs as envs
+from vllm.distributed.utils import stateless_init_process_group
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
 
+from ..utils import multi_gpu_test
+
 
 @ray.remote
 class _CUDADeviceCountStatelessTestActor:
@@ -24,10 +30,75 @@ def test_cuda_device_count_stateless():
     CUDA_VISIBLE_DEVICES is changed."""
     actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
         num_gpus=2).remote()
-    assert sorted(ray.get(
-        actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
+    assert len(
+        sorted(ray.get(
+            actor.get_cuda_visible_devices.remote()).split(","))) == 2
     assert ray.get(actor.get_count.remote()) == 2
     ray.get(actor.set_cuda_visible_devices.remote("0"))
     assert ray.get(actor.get_count.remote()) == 1
     ray.get(actor.set_cuda_visible_devices.remote(""))
     assert ray.get(actor.get_count.remote()) == 0
+
+
+def cpu_worker(rank, WORLD_SIZE):
+    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE,
+                                       backend="gloo")
+    if rank <= 2:
+        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
+                                           rank=rank,
+                                           world_size=3,
+                                           backend="gloo")
+    data = torch.tensor([rank])
+    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    if rank <= 2:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+    item = data[0].item()
+    print(f"rank: {rank}, item: {item}")
+    if rank == 3:
+        assert item == 6
+    else:
+        assert item == 18
+
+
+def gpu_worker(rank, WORLD_SIZE):
+    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE,
+                                       backend="nccl")
+    if rank <= 2:
+        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3,
+                                           backend="nccl")
+    torch.cuda.set_device(rank)
+    data = torch.tensor([rank]).cuda()
+    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    if rank <= 2:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+    item = data[0].item()
+    print(f"rank: {rank}, item: {item}")
+    if rank == 3:
+        assert item == 6
+    else:
+        assert item == 18
+
+
+@multi_gpu_test(num_gpus=4)
+@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
+def test_stateless_init_process_group(worker):
+    WORLD_SIZE = 4
+    from multiprocessing import get_context
+    ctx = get_context("fork")
+    processes = []
+    for i in range(WORLD_SIZE):
+        rank = i
+        processes.append(ctx.Process(target=worker,
+                                     args=(rank, WORLD_SIZE)))
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    for p in processes:
+        assert not p.exitcode
+    print("All processes finished.")
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 8c94ef8cb10ce..d24ce898707fc 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -5,6 +5,11 @@
 from typing import Sequence, Tuple
 
 import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
         end_layer = num_hidden_layers
 
     return (start_layer, end_layer)
+
+
+def stateless_init_process_group(init_method: str, rank: int, world_size: int,
+                                 backend: str) -> ProcessGroup:
+    """A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state.
+
+    If process A and process B call `torch.distributed.init_process_group` to
+    form a group, and we then want to form another group with processes A, B,
+    C, and D, that is not possible in PyTorch, because A and B have already
+    formed a group and C and D cannot join it. This function is a workaround
+    for this issue.
+
+    `torch.distributed.init_process_group` is a global call, while this function
+    is a stateless call. It returns a `ProcessGroup` object that can be used for
+    collective communication. With this function, process A and process B can
+    call `stateless_init_process_group` to form a group, and then processes A,
+    B, C, and D can call `stateless_init_process_group` to form another group.
+    """  # noqa
+
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
+
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        pg_options,
+    )
+
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
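
For reference, below is a minimal sketch (not part of the diff) of the property the docstring describes: a process can keep its regular default group created via `torch.distributed.init_process_group` and additionally join a group created via `stateless_init_process_group`, because the latter never touches the global state. The two-process setup, the `worker` function, and the port numbers 29600/29601 are made up for illustration; only `stateless_init_process_group` itself comes from this change.

```python
from multiprocessing import get_context

import torch
import torch.distributed as dist

from vllm.distributed.utils import stateless_init_process_group


def worker(rank: int, world_size: int):
    # Conventional global initialization: this call owns the process-wide
    # default group.
    dist.init_process_group(init_method="tcp://127.0.0.1:29600",
                            rank=rank,
                            world_size=world_size,
                            backend="gloo")
    # An additional, independent group over the same ranks. It leaves the
    # default group untouched, so both can be used side by side.
    extra_pg = stateless_init_process_group(
        init_method="tcp://127.0.0.1:29601",
        rank=rank,
        world_size=world_size,
        backend="gloo")
    data = torch.tensor([rank + 1])
    # With 2 ranks, data is [1, 2]: the first all_reduce (stateless group)
    # yields 3 on both ranks, the second (default group) yields 6.
    dist.all_reduce(data, group=extra_pg)
    dist.all_reduce(data)
    print(f"rank {rank}: {data.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    ctx = get_context("fork")
    procs = [ctx.Process(target=worker, args=(i, 2)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```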