From 1b2ca66ede13d643cca60c09086d9cba4ff5fd0e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Sep 2024 22:57:57 -0700 Subject: [PATCH] [torch.compile] register allreduce operations as custom ops (#8526) --- .buildkite/test-pipeline.yaml | 10 +- csrc/custom_all_reduce.cu | 12 -- csrc/ops.h | 2 - csrc/torch_bindings.cpp | 5 - tests/compile/__init__.py | 0 tests/compile/test_full_graph.py | 15 ++- vllm/_custom_ops.py | 6 - .../device_communicators/custom_all_reduce.py | 21 +++- vllm/distributed/parallel_state.py | 116 +++++++++++++++--- 9 files changed, 137 insertions(+), 50 deletions(-) create mode 100644 tests/compile/__init__.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..9483adcc5d587 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -163,13 +163,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: torch compile integration test - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s ./compile/test_full_graph.py - - pytest -v -s ./compile/test_wrapper.py - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -348,7 +341,10 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - vllm/compilation commands: + - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16..9b82bec44c3c6 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. - return false; -} - void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/ops.h b/csrc/ops.h index 681ab4b898ca3..ee89ad32cb025 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d7f7547fbef55..7009180a8687c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "should_custom_ar(Tensor inp, int max_size, int world_size, " - "bool full_nvlink) -> bool"); - custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); diff --git a/tests/compile/__init__.py b/tests/compile/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5452ce6be8110..6fc445539bbbe 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,9 +2,20 @@ import pytest +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test + @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_full_graph(model): +@pytest.mark.parametrize("tp_size", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model, tp_size): + + # Skip the test if there are not enough CUDA devices. + if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + # make sure these models can be captured in full graph mode if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" @@ -17,7 +28,7 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, enforce_eager=True) + llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 065fec6519ef7..a71b0a59e7269 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, offsets, rank, full_nvlink) -def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, - full_nvlink: bool) -> bool: - return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, - full_nvlink) - - def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 6229f1d6ec788..d239d645edc14 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] @@ -224,8 +230,19 @@ def register_graph_buffers(self): ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return ops.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6755b20eec9bb..1c864bcd5d708 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -21,11 +21,12 @@ """ import contextlib import pickle +import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch @@ -69,6 +70,58 @@ def _split_tensor_dict( return metadata_list, tensor_list +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"]) +def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce(tensor) + + +@inplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> None: + return + + +@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) +def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce(tensor) + + +@outplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. @@ -111,7 +164,11 @@ def __init__( use_custom_allreduce: bool, use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -149,28 +206,24 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - self.pynccl_comm: Optional[PyNcclCommunicator] + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) - else: - self.pynccl_comm = None - self.ca_comm: Optional[CustomAllreduce] + self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, ) - else: - self.ca_comm = None from vllm.distributed.device_communicators.tpu_communicator import ( TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] + self.tpu_communicator: Optional[TpuCommunicator] = None if use_tpu_communicator and self.world_size > 1: self.tpu_communicator = TpuCommunicator(group=self.cpu_group) @@ -264,16 +317,46 @@ def graph_capture( def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self._all_reduce(input_) + + if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name) + else: + torch.ops.vllm.inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + The actual all-reduce implementation. + NOTE: This operation will be applied in-place or out-of-place. Always assume this function modifies its input, but use the return value as the output. """ ca_comm = self.ca_comm - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # For TPUs, use TPU communicator. tpu_comm = self.tpu_communicator if tpu_comm is not None and not tpu_comm.disabled: @@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, + group_name="world", ) @@ -767,6 +851,7 @@ def init_model_parallel_group( backend: str, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -778,6 +863,7 @@ def init_model_parallel_group( use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, ) @@ -931,7 +1017,8 @@ def initialize_model_parallel( _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + group_name="tp") # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -947,7 +1034,8 @@ def initialize_model_parallel( _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_custom_allreduce=False) + use_custom_allreduce=False, + group_name="pp") def ensure_model_parallel_initialized(