Skip to content

Commit

Permalink
[core][distributed] add stateless_init_process_group (vllm-project#10072
Browse files Browse the repository at this point in the history
)

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
  • Loading branch information
youkaichao authored and JC1DA committed Nov 11, 2024
1 parent acd46ed commit f9026cc
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
Expand Down Expand Up @@ -431,7 +432,6 @@ steps:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
Expand Down
75 changes: 73 additions & 2 deletions tests/distributed/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import pytest
import ray
import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed.utils import stateless_init_process_group
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)

from ..utils import multi_gpu_test


@ray.remote
class _CUDADeviceCountStatelessTestActor:
Expand All @@ -24,10 +30,75 @@ def test_cuda_device_count_stateless():
CUDA_VISIBLE_DEVICES is changed."""
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote()
assert sorted(ray.get(
actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
assert len(
sorted(ray.get(
actor.get_cuda_visible_devices.remote()).split(","))) == 2
assert ray.get(actor.get_count.remote()) == 2
ray.get(actor.set_cuda_visible_devices.remote("0"))
assert ray.get(actor.get_count.remote()) == 1
ray.get(actor.set_cuda_visible_devices.remote(""))
assert ray.get(actor.get_count.remote()) == 0


def cpu_worker(rank, WORLD_SIZE):
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=WORLD_SIZE,
backend="gloo")
if rank <= 2:
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
rank=rank,
world_size=3,
backend="gloo")
data = torch.tensor([rank])
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
if rank <= 2:
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
item = data[0].item()
print(f"rank: {rank}, item: {item}")
if rank == 3:
assert item == 6
else:
assert item == 18


def gpu_worker(rank, WORLD_SIZE):
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
rank=rank,
world_size=WORLD_SIZE,
backend="nccl")
if rank <= 2:
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
rank=rank,
world_size=3,
backend="nccl")
torch.cuda.set_device(rank)
data = torch.tensor([rank]).cuda()
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
if rank <= 2:
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
item = data[0].item()
print(f"rank: {rank}, item: {item}")
if rank == 3:
assert item == 6
else:
assert item == 18


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
def test_stateless_init_process_group(worker):
WORLD_SIZE = 4
from multiprocessing import get_context
ctx = get_context("fork")
processes = []
for i in range(WORLD_SIZE):
rank = i
processes.append(ctx.Process(target=worker, args=(rank, WORLD_SIZE)))
for p in processes:
p.start()
for p in processes:
p.join()
for p in processes:
assert not p.exitcode
print("All processes finished.")
73 changes: 73 additions & 0 deletions vllm/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from typing import Sequence, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
end_layer = num_hidden_layers

return (start_layer, end_layer)


def stateless_init_process_group(init_method: str, rank: int, world_size: int,
backend: str) -> ProcessGroup:
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If we have process A and process B called `torch.distributed.init_process_group`
to form a group, and then we want to form another group with process A, B, C,
D, it is not possible in PyTorch, because process A and process B have already
formed a group, and process C and process D cannot join that group. This
function is a workaround for this issue.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `ProcessGroup` object that can be used
for collective communication. With this function, process A and process B
can call `stateless_init_process_group` to form a group, and then process A, B,
C, and D can call `stateless_init_process_group` to form another group.
""" # noqa

backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend)

store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)

group_rank = rank
group_size = world_size

# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)

pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)

pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
pg_options,
)

if backend == "gloo":
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(prefix_store,
group_rank,
group_size,
timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
elif backend == "nccl":
assert is_nccl_available()
from torch.distributed.distributed_c10d import ProcessGroupNCCL

backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout

backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")

backend_class._set_sequence_number_for_group()

pg._register_backend(device, backend_type, backend_class)

return pg

0 comments on commit f9026cc

Please sign in to comment.