From 29e77d9bd8848a5a18220289b892c01ce70403cf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:01:21 -0700 Subject: [PATCH 01/25] fix register of fused moe Signed-off-by: youkaichao --- .../layers/fused_moe/fused_moe.py | 68 +++++++++++-------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1cf5c2253ca0b..4cd0ef1df0377 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -7,6 +7,7 @@ import torch import triton import triton.language as tl +from torch.library import Library import vllm.envs as envs from vllm import _custom_ops as ops @@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype, return None -@torch.library.custom_op("vllm::inplace_fused_experts", - mutates_args=["hidden_states"]) def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a1_scale, a2_scale) -@inplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> None: pass -@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[]) +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "inplace_fused_experts(Tensor(a0!) hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> ()" # noqa +) +my_lib.impl("inplace_fused_experts", inplace_fused_experts, "CUDA") +my_lib._register_fake("inplace_fused_experts", inplace_fused_experts_fake) + + def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -517,21 +523,29 @@ def outplace_fused_experts( w2_scale, a1_scale, a2_scale) -@outplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "outplace_fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> Tensor" # noqa +) +my_lib.impl("outplace_fused_experts", outplace_fused_experts, "CUDA") +my_lib._register_fake("outplace_fused_experts", outplace_fused_experts_fake) + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, From ba4f3aded581087e84ffaa1468a970f20ead4ba4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:05:16 -0700 Subject: [PATCH 02/25] fix marlin Signed-off-by: youkaichao --- .../layers/fused_moe/fused_marlin_moe.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 93019d0d0abb6..beb72b15ea3dc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from torch.library import Library from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( @@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 -@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[]) def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, @@ -119,8 +119,7 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) -@single_marlin_moe.register_fake -def _( +def single_marlin_moe_fake( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, @@ -136,7 +135,14 @@ def _( return torch.empty_like(hidden_states) -@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[]) +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "single_marlin_moe(Tensor hidden_states, Tensor w, Tensor scales, Tensor gating_output, SymInt topk, bool renormalize, Tensor? g_idx=None, Tensor? sort_indices=None, Tensor? w_zeros=None, SymInt num_bits=8, bool is_k_full=True) -> Tensor" # noqa +) +my_lib.impl("single_marlin_moe", single_marlin_moe, "CUDA") +my_lib._register_fake("single_marlin_moe", single_marlin_moe_fake) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -324,8 +330,7 @@ def fused_marlin_moe( dim=1) -@fused_marlin_moe.register_fake -def _( +def fused_marlin_moe_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -344,3 +349,11 @@ def _( is_k_full: bool = True, ) -> torch.Tensor: return torch.empty_like(hidden_states) + + +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "fused_marlin_moe(Tensor hidden_states, Tensor w1, Tensor w2, Tensor w1_scale, Tensor w2_scale, Tensor gating_output, Tensor topk_weights, Tensor topk_ids, Tensor? g_idx1=None, Tensor? g_idx2=None, Tensor? sort_indices1=None, Tensor? sort_indices2=None, Tensor? w1_zeros=None, Tensor? w2_zeros=None, SymInt num_bits=8, bool is_k_full=True) -> Tensor" # noqa +) +my_lib.impl("fused_marlin_moe", fused_marlin_moe, "CUDA") +my_lib._register_fake("fused_marlin_moe", fused_marlin_moe_fake) From 19b4bff4a219355807f0ed934d48aed5e49a485d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:08:50 -0700 Subject: [PATCH 03/25] parallel state Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b04bbc478534c..2c323a916e211 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,6 +33,7 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from torch.library import Library import vllm.envs as envs from vllm.logger import init_logger @@ -99,8 +100,6 @@ def _register_group(group: "GroupCoordinator") -> None: if supports_custom_op(): - @torch.library.custom_op("vllm::inplace_all_reduce", - mutates_args=["tensor"]) def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() @@ -108,11 +107,16 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: raise ValueError(f"Group {group_name} is destroyed.") group._all_reduce_in_place(tensor) - @inplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> None: + def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return - @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + my_lib = Library("vllm", "FRAGMENT") + my_lib.define( + "inplace_all_reduce(Tensor(a0!) tensor, str group_name) -> ()" # noqa + ) + my_lib.impl("inplace_all_reduce", inplace_all_reduce, "CUDA") + my_lib._register_fake("inplace_all_reduce", inplace_all_reduce_fake) + def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." @@ -121,10 +125,17 @@ def outplace_all_reduce(tensor: torch.Tensor, raise ValueError(f"Group {group_name} is destroyed.") return group._all_reduce_out_place(tensor) - @outplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce_fake(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: return torch.empty_like(tensor) + my_lib = Library("vllm", "FRAGMENT") + my_lib.define( + "outplace_all_reduce(Tensor tensor, str group_name) -> Tensor" # noqa + ) + my_lib.impl("outplace_all_reduce", outplace_all_reduce, "CUDA") + my_lib._register_fake("outplace_all_reduce", outplace_all_reduce_fake) + class GroupCoordinator: """ From a09f3cb9c52afdecb40d6657d8cadc59d75dfe88 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:11:34 -0700 Subject: [PATCH 04/25] fix parallel state Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2c323a916e211..a7d4a25ae0759 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -349,6 +349,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ + if input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + return input_ + if not supports_custom_op(): self._all_reduce_in_place(input_) return input_ @@ -380,9 +385,6 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) - elif input_.is_cpu: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) else: torch.distributed.all_reduce(input_, group=self.device_group) From 12cac2821dd03a9e94a0e52e3dc8c077a7e37b42 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:13:28 -0700 Subject: [PATCH 05/25] add toy llama Signed-off-by: youkaichao --- tests/compile/piecewise/test_toy_llama.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index db6a983d70feb..a6d6f8a5f49f8 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.config import CompilationConfig @@ -17,7 +18,6 @@ from vllm.plugins import set_compilation_config -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: out.copy_(q) @@ -25,12 +25,17 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out += v -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +my_lib = Library("silly", "FRAGMENT") +my_lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()") +my_lib.impl("attention", silly_attention, "CUDA") +my_lib._register_fake("attention", silly_attention_fake) + + @dataclass class LlamaConfig: hidden_size: int = 128 From 7f0fedb903d4a131ad575958ac46a64b862b4f88 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:15:56 -0700 Subject: [PATCH 06/25] add test simple Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a34d33efba1d8..45276ecfce83e 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -6,6 +6,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter @@ -17,7 +18,6 @@ global_counter = 0 -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: global global_counter @@ -27,12 +27,17 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out[0] += 1 -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +my_lib = Library("silly", "FRAGMENT") +my_lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()") +my_lib.impl("attention", silly_attention, "CUDA") +my_lib._register_fake("attention", silly_attention_fake) + + @support_torch_compile class SillyModel(nn.Module): From 15c7728a98dedf065f52ada36ee39f3d8a871e97 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:21:26 -0700 Subject: [PATCH 07/25] flash attn Signed-off-by: youkaichao --- vllm/attention/backends/flash_attn.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ffa05e80623ac..a74c009818a5c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch +from torch.library import Library from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -595,8 +596,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -755,8 +754,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -773,3 +771,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "unified_flash_attention(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa +) +my_lib.impl("unified_flash_attention", unified_flash_attention, "CUDA") +my_lib._register_fake("unified_flash_attention", unified_flash_attention_fake) From e3ca1b64dfa83b8302517cedab47f93810d4176f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:23:39 -0700 Subject: [PATCH 08/25] flashinfer Signed-off-by: youkaichao --- vllm/attention/backends/flashinfer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5ea101ae0432f..eeb7efadb8705 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -16,6 +16,7 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch +from torch.library import Library import vllm.envs as envs from vllm import _custom_ops as ops @@ -785,8 +786,6 @@ def forward( ) -@torch.library.custom_op("vllm::unified_flash_infer", - mutates_args=["kv_cache"]) def unified_flash_infer( query: torch.Tensor, key: torch.Tensor, @@ -906,8 +905,7 @@ def unified_flash_infer( return output.view(num_tokens, hidden_size) -@unified_flash_infer.register_fake -def _( +def unified_flash_infer_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -924,3 +922,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query).contiguous() + + +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "unified_flash_infer(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa +) +my_lib.impl("unified_flash_infer", unified_flash_infer, "CUDA") +my_lib._register_fake("unified_flash_infer", unified_flash_infer_fake) From f7e296654124edaa5e22a8b8586e84285dcca6bd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:27:16 -0700 Subject: [PATCH 09/25] v1.flash_attn Signed-off-by: youkaichao --- vllm/v1/attention/backends/flash_attn.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ec07464e6a12a..ff2138b73073f 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch +from torch.library import Library from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -152,8 +153,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -217,8 +216,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -235,3 +233,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +my_lib = Library("vllm", "FRAGMENT") +my_lib.define( + "unified_flash_attention(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa +) +my_lib.impl("unified_flash_attention", unified_flash_attention, "CUDA") +my_lib._register_fake("unified_flash_attention", unified_flash_attention_fake) From f8fe3dd4d79054c035e6879d039839f35df20c92 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:41:22 -0700 Subject: [PATCH 10/25] utils Signed-off-by: youkaichao --- vllm/attention/backends/flashinfer.py | 4 ++-- vllm/utils.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index eeb7efadb8705..483e719f04b2b 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -926,7 +926,7 @@ def unified_flash_infer_fake( my_lib = Library("vllm", "FRAGMENT") my_lib.define( - "unified_flash_infer(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa -) + "unified_flash_infer" + + torch.library.infer_schema(unified_flash_infer, mutates_args=["out"])) my_lib.impl("unified_flash_infer", unified_flash_infer, "CUDA") my_lib._register_fake("unified_flash_infer", unified_flash_infer_fake) diff --git a/vllm/utils.py b/vllm/utils.py index 03cdbe6a0dc7b..2433a9c4d3fd7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -32,6 +32,7 @@ import torch.types import yaml from packaging.version import Version +from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -1512,3 +1513,18 @@ def weak_ref_tensors( if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") + + +def direct_register_custom_op( + library_name: str, + op_name: str, + op_func: Callable, + mutates_args: List[str], + fake_impl: Optional[Callable] = None, +): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + my_lib = Library(library_name, "FRAGMENT") + my_lib.define(op_name + schema_str) + my_lib.impl(op_name, op_func, "CUDA") + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) From 8334099489b007598ab2fc9f1038a87a32d02607 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:43:04 -0700 Subject: [PATCH 11/25] flashinfer Signed-off-by: youkaichao --- vllm/attention/backends/flashinfer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 483e719f04b2b..36ebbe5ac0a87 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -16,7 +16,6 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch -from torch.library import Library import vllm.envs as envs from vllm import _custom_ops as ops @@ -30,7 +29,7 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.forward_context import get_forward_context from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) + make_tensor_with_pad, direct_register_custom_op) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -924,9 +923,10 @@ def unified_flash_infer_fake( return torch.empty_like(query).contiguous() -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "unified_flash_infer" + - torch.library.infer_schema(unified_flash_infer, mutates_args=["out"])) -my_lib.impl("unified_flash_infer", unified_flash_infer, "CUDA") -my_lib._register_fake("unified_flash_infer", unified_flash_infer_fake) +direct_register_custom_op( + library_name="vllm", + op_name="unified_flash_infer", + op_func=unified_flash_infer, + mutates_args=["kv_cache"], + fake_impl=unified_flash_infer_fake, +) From 5ec48b934c9f29c9dbb31303bcaabaab48e3e3c0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:44:57 -0700 Subject: [PATCH 12/25] simple Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 13 ++++++++----- vllm/attention/backends/flashinfer.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 45276ecfce83e..667924f6b7d7b 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -6,12 +6,12 @@ import torch from torch import nn -from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel +from vllm.utils import direct_register_custom_op os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) @@ -32,10 +32,13 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return -my_lib = Library("silly", "FRAGMENT") -my_lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()") -my_lib.impl("attention", silly_attention, "CUDA") -my_lib._register_fake("attention", silly_attention_fake) +direct_register_custom_op( + library_name="silly", + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, +) @support_torch_compile diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36ebbe5ac0a87..ad0b66cc59547 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -28,8 +28,8 @@ is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention from vllm.forward_context import get_forward_context -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad, direct_register_custom_op) +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + get_kv_cache_torch_dtype, make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, From 449121cc2879415eccd58a13dae1cc74aabe5cca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:45:55 -0700 Subject: [PATCH 13/25] toy llama Signed-off-by: youkaichao --- tests/compile/piecewise/test_toy_llama.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index a6d6f8a5f49f8..9b4b046cecdca 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,7 +8,6 @@ import torch from torch import nn -from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.config import CompilationConfig @@ -16,6 +15,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_compilation_config +from vllm.utils import direct_register_custom_op def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -30,10 +30,13 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return -my_lib = Library("silly", "FRAGMENT") -my_lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a3!) out) -> ()") -my_lib.impl("attention", silly_attention, "CUDA") -my_lib._register_fake("attention", silly_attention_fake) +direct_register_custom_op( + library_name="silly", + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, +) @dataclass From eca0f28a05b45794724237ef9263692ecdcd4d36 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:47:17 -0700 Subject: [PATCH 14/25] flash attn Signed-off-by: youkaichao --- vllm/attention/backends/flash_attn.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index a74c009818a5c..6244faadb5d37 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch -from torch.library import Library from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -15,7 +14,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.forward_context import get_forward_context -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -773,9 +773,10 @@ def unified_flash_attention_fake( return torch.empty_like(query) -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "unified_flash_attention(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, ) -my_lib.impl("unified_flash_attention", unified_flash_attention, "CUDA") -my_lib._register_fake("unified_flash_attention", unified_flash_attention_fake) From dc23b25b5174d7b06b6d10b5530d2233ccd29cc6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:48:54 -0700 Subject: [PATCH 15/25] parallel state Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a7d4a25ae0759..09351833060a9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,12 +33,11 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -from torch.library import Library import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import supports_custom_op +from vllm.utils import supports_custom_op, direct_register_custom_op @dataclass @@ -110,12 +109,13 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return - my_lib = Library("vllm", "FRAGMENT") - my_lib.define( - "inplace_all_reduce(Tensor(a0!) tensor, str group_name) -> ()" # noqa + direct_register_custom_op( + library_name="vllm", + op_name="inplace_all_reduce", + op_func=inplace_all_reduce, + mutates_args=["tensor"], + fake_impl=inplace_all_reduce_fake, ) - my_lib.impl("inplace_all_reduce", inplace_all_reduce, "CUDA") - my_lib._register_fake("inplace_all_reduce", inplace_all_reduce_fake) def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: @@ -129,12 +129,13 @@ def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) - my_lib = Library("vllm", "FRAGMENT") - my_lib.define( - "outplace_all_reduce(Tensor tensor, str group_name) -> Tensor" # noqa + direct_register_custom_op( + library_name="vllm", + op_name="outplace_all_reduce", + op_func=outplace_all_reduce, + mutates_args=[], + fake_impl=outplace_all_reduce_fake, ) - my_lib.impl("outplace_all_reduce", outplace_all_reduce, "CUDA") - my_lib._register_fake("outplace_all_reduce", outplace_all_reduce_fake) class GroupCoordinator: From cb4f15711529935ab3b9622ad668c5683e4b63e5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:50:08 -0700 Subject: [PATCH 16/25] marlin Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 09351833060a9..fc984a510b76a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,7 +37,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import supports_custom_op, direct_register_custom_op +from vllm.utils import direct_register_custom_op, supports_custom_op @dataclass diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index beb72b15ea3dc..b0828f3e57d74 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -3,12 +3,12 @@ from typing import Optional import torch -from torch.library import Library from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) from vllm.scalar_type import scalar_types +from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): @@ -135,12 +135,13 @@ def single_marlin_moe_fake( return torch.empty_like(hidden_states) -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "single_marlin_moe(Tensor hidden_states, Tensor w, Tensor scales, Tensor gating_output, SymInt topk, bool renormalize, Tensor? g_idx=None, Tensor? sort_indices=None, Tensor? w_zeros=None, SymInt num_bits=8, bool is_k_full=True) -> Tensor" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="single_marlin_moe", + op_func=single_marlin_moe, + mutates_args=[], + fake_impl=single_marlin_moe_fake, ) -my_lib.impl("single_marlin_moe", single_marlin_moe, "CUDA") -my_lib._register_fake("single_marlin_moe", single_marlin_moe_fake) def fused_marlin_moe( @@ -351,9 +352,10 @@ def fused_marlin_moe_fake( return torch.empty_like(hidden_states) -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "fused_marlin_moe(Tensor hidden_states, Tensor w1, Tensor w2, Tensor w1_scale, Tensor w2_scale, Tensor gating_output, Tensor topk_weights, Tensor topk_ids, Tensor? g_idx1=None, Tensor? g_idx2=None, Tensor? sort_indices1=None, Tensor? sort_indices2=None, Tensor? w1_zeros=None, Tensor? w2_zeros=None, SymInt num_bits=8, bool is_k_full=True) -> Tensor" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="fused_marlin_moe", + op_func=fused_marlin_moe, + mutates_args=[], + fake_impl=fused_marlin_moe_fake, ) -my_lib.impl("fused_marlin_moe", fused_marlin_moe, "CUDA") -my_lib._register_fake("fused_marlin_moe", fused_marlin_moe_fake) From 3c5e7dd6aa6fa571d8059a4c153c8b34ab537342 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:51:46 -0700 Subject: [PATCH 17/25] fused moe Signed-off-by: youkaichao --- .../layers/fused_moe/fused_moe.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4cd0ef1df0377..710778a78b10c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -7,12 +7,12 @@ import torch import triton import triton.language as tl -from torch.library import Library import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -498,12 +498,13 @@ def inplace_fused_experts_fake( pass -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "inplace_fused_experts(Tensor(a0!) hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> ()" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, ) -my_lib.impl("inplace_fused_experts", inplace_fused_experts, "CUDA") -my_lib._register_fake("inplace_fused_experts", inplace_fused_experts_fake) def outplace_fused_experts( @@ -538,12 +539,13 @@ def outplace_fused_experts_fake( return torch.empty_like(hidden_states) -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "outplace_fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool use_fp8_w8a8=False, bool use_int8_w8a16=False, Tensor? w1_scale=None, Tensor? w2_scale=None, Tensor? a1_scale=None, Tensor? a2_scale=None) -> Tensor" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, ) -my_lib.impl("outplace_fused_experts", outplace_fused_experts, "CUDA") -my_lib._register_fake("outplace_fused_experts", outplace_fused_experts_fake) def fused_experts(hidden_states: torch.Tensor, From 0547328071dbd135f5ccd3258f1f968d42b9fc5e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:52:40 -0700 Subject: [PATCH 18/25] v1.flash_attn Signed-off-by: youkaichao --- vllm/v1/attention/backends/flash_attn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ff2138b73073f..86f4665b27a91 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch -from torch.library import Library from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -235,9 +235,10 @@ def unified_flash_attention_fake( return torch.empty_like(query) -my_lib = Library("vllm", "FRAGMENT") -my_lib.define( - "unified_flash_attention(Tensor query, Tensor key, Tensor value, SymInt num_heads, SymInt head_size, SymInt num_kv_heads, Tensor(a6!) kv_cache, str kv_cache_dtype, float k_scale, float v_scale, float softmax_scale, SymInt[]? window_size=None, Tensor? alibi_slopes=None, float? logits_soft_cap=None) -> Tensor" # noqa +direct_register_custom_op( + library_name="vllm", + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, ) -my_lib.impl("unified_flash_attention", unified_flash_attention, "CUDA") -my_lib._register_fake("unified_flash_attention", unified_flash_attention_fake) From d87bfe0026d3172fa1cb5307ae8ed1999141d9c8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 16:55:10 -0700 Subject: [PATCH 19/25] add comments Signed-off-by: youkaichao --- vllm/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/utils.py b/vllm/utils.py index 2433a9c4d3fd7..456b6e5648bee 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1522,6 +1522,13 @@ def direct_register_custom_op( mutates_args: List[str], fake_impl: Optional[Callable] = None, ): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + """ schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) my_lib = Library(library_name, "FRAGMENT") my_lib.define(op_name + schema_str) From d515d619a4f2f97b0bdd17d1b8652a4d82629fc4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 18:40:15 -0700 Subject: [PATCH 20/25] hack fix library Signed-off-by: youkaichao --- tests/compile/piecewise/piecewise_compilation_config.json | 2 +- tests/compile/piecewise/test_simple.py | 8 ++++---- tests/compile/piecewise/test_toy_llama.py | 8 ++++---- vllm/utils.py | 7 ++++++- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json index 03d077b76f627..c5d7d066e0103 100644 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ b/tests/compile/piecewise/piecewise_compilation_config.json @@ -1,4 +1,4 @@ { "use_cudagraph": true, - "non_cudagraph_ops": ["silly.attention"] + "non_cudagraph_ops": ["vllm.toy_attention"] } \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 667924f6b7d7b..f28b4d624d4bd 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -33,8 +33,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="silly", - op_name="attention", + library_name="vllm", + op_name="toy_attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, @@ -57,12 +57,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + 1 x = x + 2 out = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, out) + torch.ops.vllm.toy_attention(x, x, x, out) x = out x = x - 2 x = x - 1 out = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, out) + torch.ops.vllm.toy_attention(x, x, x, out) x = out x = x + 1 return x diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 9b4b046cecdca..bc583186986d0 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -31,8 +31,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="silly", - op_name="attention", + library_name="vllm", + op_name="toy_attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, @@ -103,7 +103,7 @@ def forward( k = k + positions.unsqueeze(1) attn_output = torch.empty_like(q) - torch.ops.silly.attention(q, k, v, attn_output) + torch.ops.vllm.toy_attention(q, k, v, attn_output) output = self.output_projection(attn_output) return output @@ -179,7 +179,7 @@ def run_model(llama_config, set_compilation_config( CompilationConfig( use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], + non_cudagraph_ops=["vllm.toy_attention"], )) else: set_compilation_config(CompilationConfig(use_cudagraph=True, )) diff --git a/vllm/utils.py b/vllm/utils.py index 456b6e5648bee..2781edaf4f197 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1515,6 +1515,9 @@ def weak_ref_tensors( raise ValueError("Invalid type for tensors") +vllm_lib = Library("vllm", "FRAGMENT") + + def direct_register_custom_op( library_name: str, op_name: str, @@ -1530,7 +1533,9 @@ def direct_register_custom_op( for more details. """ schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - my_lib = Library(library_name, "FRAGMENT") + # FIXME after https://github.com/pytorch/pytorch/issues/139444 is resolved + assert library_name == "vllm" + my_lib = vllm_lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") if fake_impl is not None: From 012533a72bec0264576755a278e98ca83147169d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 19:06:06 -0700 Subject: [PATCH 21/25] fix doc build Signed-off-by: youkaichao --- vllm/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/utils.py b/vllm/utils.py index 2781edaf4f197..f397d5408fc68 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1515,6 +1515,14 @@ def weak_ref_tensors( raise ValueError("Invalid type for tensors") +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(torch, _MockModule) + except ModuleNotFoundError: + return False + + vllm_lib = Library("vllm", "FRAGMENT") @@ -1532,6 +1540,8 @@ def direct_register_custom_op( See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. """ + if is_in_doc_build(): + return schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) # FIXME after https://github.com/pytorch/pytorch/issues/139444 is resolved assert library_name == "vllm" From 2ef5e40b1205008f734ec1a2594cd644850607d1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 21:12:02 -0700 Subject: [PATCH 22/25] fix lifetime Signed-off-by: youkaichao --- .../piecewise_compilation_config.json | 2 +- tests/compile/piecewise/test_simple.py | 12 ++++++---- tests/compile/piecewise/test_toy_llama.py | 12 ++++++---- vllm/__init__.py | 5 +++++ vllm/utils.py | 22 +++++-------------- 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json index c5d7d066e0103..03d077b76f627 100644 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ b/tests/compile/piecewise/piecewise_compilation_config.json @@ -1,4 +1,4 @@ { "use_cudagraph": true, - "non_cudagraph_ops": ["vllm.toy_attention"] + "non_cudagraph_ops": ["silly.attention"] } \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index f28b4d624d4bd..e7d5c55cd997d 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -6,6 +6,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter @@ -17,6 +18,9 @@ global_counter = 0 +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: @@ -33,8 +37,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="vllm", - op_name="toy_attention", + library_name="silly", + op_name="attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, @@ -57,12 +61,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + 1 x = x + 2 out = torch.empty_like(x) - torch.ops.vllm.toy_attention(x, x, x, out) + torch.ops.silly.attention(x, x, x, out) x = out x = x - 2 x = x - 1 out = torch.empty_like(x) - torch.ops.vllm.toy_attention(x, x, x, out) + torch.ops.silly.attention(x, x, x, out) x = out x = x + 1 return x diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bc583186986d0..5290aad2658ab 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.config import CompilationConfig @@ -17,6 +18,9 @@ from vllm.plugins import set_compilation_config from vllm.utils import direct_register_custom_op +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: @@ -31,8 +35,8 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="vllm", - op_name="toy_attention", + library_name="silly", + op_name="attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, @@ -103,7 +107,7 @@ def forward( k = k + positions.unsqueeze(1) attn_output = torch.empty_like(q) - torch.ops.vllm.toy_attention(q, k, v, attn_output) + torch.ops.silly.attention(q, k, v, attn_output) output = self.output_projection(attn_output) return output @@ -179,7 +183,7 @@ def run_model(llama_config, set_compilation_config( CompilationConfig( use_cudagraph=True, - non_cudagraph_ops=["vllm.toy_attention"], + non_cudagraph_ops=["silly.attention"], )) else: set_compilation_config(CompilationConfig(use_cudagraph=True, )) diff --git a/vllm/__init__.py b/vllm/__init__.py index 8f477ea84756d..4b3026fc47fc2 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,5 +1,7 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +from torch.library import Library + from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine @@ -14,6 +16,9 @@ from .version import __version__, __version_tuple__ +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + __all__ = [ "__version__", "__version_tuple__", diff --git a/vllm/utils.py b/vllm/utils.py index f397d5408fc68..1ff542fdf5934 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1515,17 +1515,6 @@ def weak_ref_tensors( raise ValueError("Invalid type for tensors") -def is_in_doc_build() -> bool: - try: - from sphinx.ext.autodoc.mock import _MockModule - return isinstance(torch, _MockModule) - except ModuleNotFoundError: - return False - - -vllm_lib = Library("vllm", "FRAGMENT") - - def direct_register_custom_op( library_name: str, op_name: str, @@ -1539,13 +1528,14 @@ def direct_register_custom_op( directly registers a custom op and dispatches it to the CUDA backend. See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. It is important to have one line of code + `my_lib = Library(library_name, "FRAGMENT")` outside of the function + to keep the library object alive. """ - if is_in_doc_build(): - return schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - # FIXME after https://github.com/pytorch/pytorch/issues/139444 is resolved - assert library_name == "vllm" - my_lib = vllm_lib + my_lib = Library(library_name, "FRAGMENT") my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") if fake_impl is not None: From 65945e0cf906eff5ee9da5a655ebbbe8862f8107 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 21:17:31 -0700 Subject: [PATCH 23/25] add lib arg Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 1 + tests/compile/piecewise/test_toy_llama.py | 1 + vllm/__init__.py | 5 ----- vllm/utils.py | 11 ++++++++++- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index e7d5c55cd997d..85f2b856e049e 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -42,6 +42,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, + lib=silly_lib, ) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 5290aad2658ab..a75303c45ccc4 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -40,6 +40,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, + lib=silly_lib, ) diff --git a/vllm/__init__.py b/vllm/__init__.py index 4b3026fc47fc2..8f477ea84756d 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,7 +1,5 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" -from torch.library import Library - from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine @@ -16,9 +14,6 @@ from .version import __version__, __version_tuple__ -# create a library to hold the custom op -vllm_lib = Library("vllm", "FRAGMENT") # noqa - __all__ = [ "__version__", "__version_tuple__", diff --git a/vllm/utils.py b/vllm/utils.py index 1ff542fdf5934..68e8c0efa4514 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1515,12 +1515,17 @@ def weak_ref_tensors( raise ValueError("Invalid type for tensors") +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + def direct_register_custom_op( library_name: str, op_name: str, op_func: Callable, mutates_args: List[str], fake_impl: Optional[Callable] = None, + lib: Optional[Library] = None, ): """ `torch.library.custom_op` can have significant overhead because it @@ -1535,7 +1540,11 @@ def direct_register_custom_op( to keep the library object alive. """ schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - my_lib = Library(library_name, "FRAGMENT") + if library_name == "vllm": + my_lib = vllm_lib + else: + assert lib is not None + my_lib = lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") if fake_impl is not None: From 529b28a975637571a426c6fc0099e3ae994ecf77 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 21:22:02 -0700 Subject: [PATCH 24/25] fix Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 3 +-- tests/compile/piecewise/test_toy_llama.py | 3 +-- vllm/attention/backends/flash_attn.py | 1 - vllm/attention/backends/flashinfer.py | 1 - vllm/distributed/parallel_state.py | 2 -- .../layers/fused_moe/fused_marlin_moe.py | 2 -- .../layers/fused_moe/fused_moe.py | 2 -- vllm/utils.py | 18 ++++++++---------- vllm/v1/attention/backends/flash_attn.py | 1 - 9 files changed, 10 insertions(+), 23 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 85f2b856e049e..d151d62516b07 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -37,12 +37,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="silly", op_name="attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, - lib=silly_lib, + target_lib=silly_lib, ) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index a75303c45ccc4..e3e5a7d0fc5a5 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -35,12 +35,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, direct_register_custom_op( - library_name="silly", op_name="attention", op_func=silly_attention, mutates_args=["out"], fake_impl=silly_attention_fake, - lib=silly_lib, + target_lib=silly_lib, ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6244faadb5d37..c294fcf7f08fe 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -774,7 +774,6 @@ def unified_flash_attention_fake( direct_register_custom_op( - library_name="vllm", op_name="unified_flash_attention", op_func=unified_flash_attention, mutates_args=["kv_cache"], diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ad0b66cc59547..234c87d5c4edb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -924,7 +924,6 @@ def unified_flash_infer_fake( direct_register_custom_op( - library_name="vllm", op_name="unified_flash_infer", op_func=unified_flash_infer, mutates_args=["kv_cache"], diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fc984a510b76a..94ba41a016f6d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -110,7 +110,6 @@ def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return direct_register_custom_op( - library_name="vllm", op_name="inplace_all_reduce", op_func=inplace_all_reduce, mutates_args=["tensor"], @@ -130,7 +129,6 @@ def outplace_all_reduce_fake(tensor: torch.Tensor, return torch.empty_like(tensor) direct_register_custom_op( - library_name="vllm", op_name="outplace_all_reduce", op_func=outplace_all_reduce, mutates_args=[], diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b0828f3e57d74..4741d69de11ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -136,7 +136,6 @@ def single_marlin_moe_fake( direct_register_custom_op( - library_name="vllm", op_name="single_marlin_moe", op_func=single_marlin_moe, mutates_args=[], @@ -353,7 +352,6 @@ def fused_marlin_moe_fake( direct_register_custom_op( - library_name="vllm", op_name="fused_marlin_moe", op_func=fused_marlin_moe, mutates_args=[], diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 710778a78b10c..340da32263c1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -499,7 +499,6 @@ def inplace_fused_experts_fake( direct_register_custom_op( - library_name="vllm", op_name="inplace_fused_experts", op_func=inplace_fused_experts, mutates_args=["hidden_states"], @@ -540,7 +539,6 @@ def outplace_fused_experts_fake( direct_register_custom_op( - library_name="vllm", op_name="outplace_fused_experts", op_func=outplace_fused_experts, mutates_args=[], diff --git a/vllm/utils.py b/vllm/utils.py index 68e8c0efa4514..e989257302858 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1520,12 +1520,11 @@ def weak_ref_tensors( def direct_register_custom_op( - library_name: str, op_name: str, op_func: Callable, mutates_args: List[str], fake_impl: Optional[Callable] = None, - lib: Optional[Library] = None, + target_lib: Optional[Library] = None, ): """ `torch.library.custom_op` can have significant overhead because it @@ -1534,17 +1533,16 @@ def direct_register_custom_op( See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + IMPORTANT: the lifetime of the operator is tied to the lifetime of the - library object. It is important to have one line of code - `my_lib = Library(library_name, "FRAGMENT")` outside of the function - to keep the library object alive. + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. """ schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - if library_name == "vllm": - my_lib = vllm_lib - else: - assert lib is not None - my_lib = lib + my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") if fake_impl is not None: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 86f4665b27a91..b2af89ebf854a 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -236,7 +236,6 @@ def unified_flash_attention_fake( direct_register_custom_op( - library_name="vllm", op_name="unified_flash_attention", op_func=unified_flash_attention, mutates_args=["kv_cache"], From d4b13991fcc25b6a0498ff580d57835752dd0787 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 31 Oct 2024 21:25:58 -0700 Subject: [PATCH 25/25] fix doc build Signed-off-by: youkaichao --- vllm/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/utils.py b/vllm/utils.py index e989257302858..5488719cc99b0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1515,6 +1515,14 @@ def weak_ref_tensors( raise ValueError("Invalid type for tensors") +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(torch, _MockModule) + except ModuleNotFoundError: + return False + + # create a library to hold the custom op vllm_lib = Library("vllm", "FRAGMENT") # noqa @@ -1541,6 +1549,8 @@ def direct_register_custom_op( library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ + if is_in_doc_build(): + return schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str)