Skip to content

Commit

Permalink
[Core][Distributed] refactor custom allreduce to support multiple tp …
Browse files Browse the repository at this point in the history
…groups (vllm-project#4754)
  • Loading branch information
youkaichao authored May 13, 2024
1 parent 0e2617f commit 296cba6
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 226 deletions.
22 changes: 11 additions & 11 deletions tests/distributed/test_comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tensor_parallel_size)
(r + 1) for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
Expand All @@ -38,15 +38,15 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
Expand All @@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="cuda").reshape(tensor_size) * (r + 1)
for r in range(tensor_parallel_size)
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
Expand All @@ -66,15 +66,15 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
test_dict = {
# device tensor
Expand Down Expand Up @@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
multi_process_tensor_parallel(tensor_parallel_size, test_target)
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
87 changes: 58 additions & 29 deletions tests/distributed/test_custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import torch
import torch.distributed as dist

from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_ca_communicator)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)

Expand All @@ -18,17 +20,36 @@


@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(world_size, rank, distributed_init_port):
def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank,
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

custom_all_reduce.init_custom_ar()
group = get_tensor_model_parallel_group()

# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
del data

# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication = rank // tp_size + 1

for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_all_reduce.capture():
with graph_capture():
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
Expand All @@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test
# synchronization
dist.all_reduce(inp1)
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2)
for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test
# synchronization
dist.all_reduce(inp1, group=group)
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
assert torch.allclose(out1, inp1)
assert torch.allclose(out2, inp2)


@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(world_size, rank, distributed_init_port):
def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank,
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication = rank // tp_size + 1
sz = 1024
custom_all_reduce.init_custom_ar()
fa = custom_all_reduce.get_handle()
fa = get_tp_ca_communicator()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))

inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
multi_process_tensor_parallel(tensor_parallel_size, test_target)


if __name__ == "__main__":
multi_process_tensor_parallel(2, graph_allreduce)
def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
4 changes: 2 additions & 2 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from vllm.distributed.communication_op import ( # noqa
graph_capture_mode, tensor_model_parallel_all_reduce)
graph_mode, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
Expand Down Expand Up @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture_mode():
with graph_mode():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
Expand Down
45 changes: 34 additions & 11 deletions vllm/distributed/communication_op.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import namedtuple
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
Expand All @@ -9,12 +9,13 @@
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_ca_communicator,
get_tp_pynccl_communicator)


@contextmanager
def graph_capture_mode():
# In graph capture, we have to be very careful about the collective
def graph_mode():
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
Expand All @@ -24,10 +25,32 @@ def graph_capture_mode():
#
# Note that custom allreduce will have a runtime check, if the tensor size
# is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm = get_tp_pynccl_communicator()
assert pynccl_comm is not None
with pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream()):
if pynccl_comm is None:
context = nullcontext()
else:
context = pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream())
with context:
yield


@contextmanager
def graph_capture():
"""
`graph_capture` is a context manager which should include the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed.
"""
ca_comm = get_tp_ca_communicator()
context = nullcontext() if ca_comm is None else ca_comm.capture()
with context:
yield


Expand All @@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
ca_comm = get_tp_ca_communicator()

# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
out = custom_all_reduce(input_)
if out is not None:
return out
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
return out
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
Expand Down
Loading

0 comments on commit 296cba6

Please sign in to comment.