
[Core][Distributed] enable multiple tp group (vllm-project#4512)
Co-authored-by: Zhuohan Li <[email protected]>
2 people authored and robertgshaw2-neuralmagic committed May 6, 2024
1 parent 2017aaf commit 27f0c2b
Showing 4 changed files with 43 additions and 4 deletions.
.buildkite/test-pipeline.yaml: 8 additions & 3 deletions
@@ -25,19 +25,24 @@ steps:
 - label: Distributed Comm Ops Test
   command: pytest -v -s test_comm_ops.py
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
 
 - label: Distributed Tests
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
   commands:
-  - pytest -v -s test_pynccl.py
   - pytest -v -s test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
 
+- label: Distributed Tests (Multiple Groups)
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 4
+  commands:
+  - pytest -v -s test_pynccl.py
+
 - label: Engine Test
   command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
 
.buildkite/test-template.j2: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ steps:
     plugins:
       - kubernetes:
           podSpec:
+            {% if step.num_gpus %}
+            priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
+            {% endif %}
             volumes:
             - name: dshm
               emptyDir:
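The template change means that any step declaring `num_gpus` now gets a Kubernetes `priorityClassName` of the form `gpu-priority-cls-<n>` in its rendered podSpec, so the new 4-GPU step can be scheduled under its own priority class. Below is a minimal rendering sketch, not part of the commit, assuming the `jinja2` package and `trim_blocks`/`lstrip_blocks` rendering options (the real CI may render the template differently):

# Minimal sketch of how the new template branch renders (assumption: jinja2
# with trim_blocks/lstrip_blocks; only an excerpt of the podSpec is shown).
import jinja2

snippet = """\
podSpec:
  {% if step.num_gpus %}
  priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
  {% endif %}
  volumes:
  - name: dshm
"""

template = jinja2.Template(snippet, trim_blocks=True, lstrip_blocks=True)

print(template.render(step={"num_gpus": 4}))
# podSpec:
#   priorityClassName: gpu-priority-cls-4
#   volumes:
#   - name: dshm

print(template.render(step={}))  # no num_gpus -> no priorityClassName
# podSpec:
#   volumes:
#   - name: dshm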
tests/distributed/test_pynccl.py: 28 additions & 0 deletions
@@ -58,6 +58,34 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
+@worker_fn_wrapper
+def multiple_tp_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
+        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
+    comm = NCCLCommunicator(group=group, device=device)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        comm.all_reduce(tensor)
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 4
+    else:
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp():
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
 @worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
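The asserted values follow from the group wiring: each sub-group has two ranks, so an all-reduce of an all-ones tensor yields 2 in every element, and ranks 0 and 1 run the all-reduce twice, yielding 4. A plain-Python sanity check of that arithmetic (illustration only, not part of the commit, no GPUs needed):

# Illustration only: reproduces the values the new test asserts.
def expected_after_all_reduce(group_ranks, times):
    # summing an all-ones tensor across a group multiplies each element
    # by the group size on every all-reduce
    value = 1.0
    for _ in range(times):
        value *= len(group_ranks)
    return value

assert expected_after_all_reduce([0, 1], times=2) == 4.0  # ranks 0 and 1
assert expected_after_all_reduce([2, 3], times=1) == 2.0  # ranks 2 and 3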
vllm/distributed/device_communicators/pynccl.py: 4 additions & 1 deletion
@@ -232,14 +232,17 @@ def __init__(
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "NCCLCommunicator should be attached to a non-NCCL group.")
         self.group = group
+        # note: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
-        dist.broadcast(tensor, src=0, group=group)
+        ranks = dist.get_process_group_ranks(group)
+        # arg `src` in `broadcast` is the global rank
+        dist.broadcast(tensor, src=ranks[0], group=group)
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
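This pynccl change is the core fix: `NCCLCommunicator.rank` is a rank within its group, but `torch.distributed.broadcast` expects `src` to be a global rank, so a hard-coded `src=0` would name a process outside the second tensor-parallel group. A small illustration of the mapping (plain Python, not part of the commit, assuming two groups of two ranks as in the test above):

# Illustration only: group rank vs. global rank for two tensor-parallel groups.
groups = {"tp group A": [0, 1], "tp group B": [2, 3]}  # global ranks per group

for name, ranks in groups.items():
    for group_rank, global_rank in enumerate(ranks):
        # dist.get_rank(group) returns group_rank; dist.broadcast(src=...) wants global_rank
        print(f"{name}: group rank {group_rank} is global rank {global_rank}")
    # the NCCL unique id is created by the group's first member and must be
    # broadcast from its global rank, i.e. ranks[0] rather than the literal 0
    print(f"{name}: broadcast unique id from src={ranks[0]}")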
