From 96e0c9cbbd65ad0b8ad20611b90bcc86a8559aae Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 21:56:09 -0700 Subject: [PATCH] [torch.compile] directly register custom op (#9896) Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 20 ++++-- tests/compile/piecewise/test_toy_llama.py | 20 ++++-- vllm/attention/backends/flash_attn.py | 16 +++-- vllm/attention/backends/flashinfer.py | 17 +++-- vllm/distributed/parallel_state.py | 34 +++++++--- .../layers/fused_moe/fused_marlin_moe.py | 25 +++++-- .../layers/fused_moe/fused_moe.py | 68 +++++++++++-------- vllm/utils.py | 45 ++++++++++++ vllm/v1/attention/backends/flash_attn.py | 14 ++-- 9 files changed, 192 insertions(+), 67 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a34d33efba1d8..d151d62516b07 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -6,18 +6,22 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel +from vllm.utils import direct_register_custom_op os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) global_counter = 0 +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: global global_counter @@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out[0] += 1 -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + @support_torch_compile class SillyModel(nn.Module): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index db6a983d70feb..e3e5a7d0fc5a5 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.config import CompilationConfig @@ -15,9 +16,12 @@ from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_compilation_config +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: out.copy_(q) @@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out += v -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + @dataclass class LlamaConfig: hidden_size: int = 128 diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ffa05e80623ac..c294fcf7f08fe 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,7 +14,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.forward_context import get_forward_context -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -595,8 +596,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -755,8 +754,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -773,3 +771,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +direct_register_custom_op( + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, +) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5ea101ae0432f..234c87d5c4edb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -28,8 +28,8 @@ is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention from vllm.forward_context import get_forward_context -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + get_kv_cache_torch_dtype, make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -785,8 +785,6 @@ def forward( ) -@torch.library.custom_op("vllm::unified_flash_infer", - mutates_args=["kv_cache"]) def unified_flash_infer( query: torch.Tensor, key: torch.Tensor, @@ -906,8 +904,7 @@ def unified_flash_infer( return output.view(num_tokens, hidden_size) -@unified_flash_infer.register_fake -def _( +def unified_flash_infer_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -924,3 +921,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_flash_infer", + op_func=unified_flash_infer, + mutates_args=["kv_cache"], + fake_impl=unified_flash_infer_fake, +) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b04bbc478534c..94ba41a016f6d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,7 +37,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import supports_custom_op +from vllm.utils import direct_register_custom_op, supports_custom_op @dataclass @@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None: if supports_custom_op(): - @torch.library.custom_op("vllm::inplace_all_reduce", - mutates_args=["tensor"]) def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() @@ -108,11 +106,16 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: raise ValueError(f"Group {group_name} is destroyed.") group._all_reduce_in_place(tensor) - @inplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> None: + def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return - @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + direct_register_custom_op( + op_name="inplace_all_reduce", + op_func=inplace_all_reduce, + mutates_args=["tensor"], + fake_impl=inplace_all_reduce_fake, + ) + def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." @@ -121,10 +124,17 @@ def outplace_all_reduce(tensor: torch.Tensor, raise ValueError(f"Group {group_name} is destroyed.") return group._all_reduce_out_place(tensor) - @outplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce_fake(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: return torch.empty_like(tensor) + direct_register_custom_op( + op_name="outplace_all_reduce", + op_func=outplace_all_reduce, + mutates_args=[], + fake_impl=outplace_all_reduce_fake, + ) + class GroupCoordinator: """ @@ -338,6 +348,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ + if input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + return input_ + if not supports_custom_op(): self._all_reduce_in_place(input_) return input_ @@ -369,9 +384,6 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) - elif input_.is_cpu: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) else: torch.distributed.all_reduce(input_, group=self.device_group) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 93019d0d0abb6..4741d69de11ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,6 +8,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) from vllm.scalar_type import scalar_types +from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): @@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 -@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[]) def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, @@ -119,8 +119,7 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) -@single_marlin_moe.register_fake -def _( +def single_marlin_moe_fake( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, @@ -136,7 +135,14 @@ def _( return torch.empty_like(hidden_states) -@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[]) +direct_register_custom_op( + op_name="single_marlin_moe", + op_func=single_marlin_moe, + mutates_args=[], + fake_impl=single_marlin_moe_fake, +) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -324,8 +330,7 @@ def fused_marlin_moe( dim=1) -@fused_marlin_moe.register_fake -def _( +def fused_marlin_moe_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -344,3 +349,11 @@ def _( is_k_full: bool = True, ) -> torch.Tensor: return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="fused_marlin_moe", + op_func=fused_marlin_moe, + mutates_args=[], + fake_impl=fused_marlin_moe_fake, +) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1cf5c2253ca0b..340da32263c1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -12,6 +12,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype, return None -@torch.library.custom_op("vllm::inplace_fused_experts", - mutates_args=["hidden_states"]) def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a1_scale, a2_scale) -@inplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> None: pass -@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[]) +direct_register_custom_op( + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, +) + + def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -517,21 +523,29 @@ def outplace_fused_experts( w2_scale, a1_scale, a2_scale) -@outplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) +direct_register_custom_op( + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, +) + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/utils.py b/vllm/utils.py index 03cdbe6a0dc7b..5488719cc99b0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -32,6 +32,7 @@ import torch.types import yaml from packaging.version import Version +from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -1512,3 +1513,47 @@ def weak_ref_tensors( if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") + + +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(torch, _MockModule) + except ModuleNotFoundError: + return False + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: List[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + """ + if is_in_doc_build(): + return + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str) + my_lib.impl(op_name, op_func, "CUDA") + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ec07464e6a12a..b2af89ebf854a 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -152,8 +153,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -217,8 +216,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -235,3 +233,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +direct_register_custom_op( + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, +)