From bbf023b24ad53810f59926e5ad2bef422ec90e42 Mon Sep 17 00:00:00 2001 From: skrider Date: Fri, 8 Mar 2024 01:11:50 +0000 Subject: [PATCH 01/81] vendor flash-attention --- .gitmodules | 3 +++ csrc/flash-attention | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 csrc/flash-attention diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000..8790f31f51adb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "csrc/flash-attention"] + path = csrc/flash-attention + url = git@github.com:skrider/flash-attention.git diff --git a/csrc/flash-attention b/csrc/flash-attention new file mode 160000 index 0000000000000..61a777247900f --- /dev/null +++ b/csrc/flash-attention @@ -0,0 +1 @@ +Subproject commit 61a777247900f6c2a37376f3ffd7134385fdc95c From 38d422d15de9a9dd043f9a237ef4283fcdc10039 Mon Sep 17 00:00:00 2001 From: skrider Date: Wed, 27 Mar 2024 05:26:49 +0000 Subject: [PATCH 02/81] update vendored flash-attention --- csrc/flash-attention | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash-attention b/csrc/flash-attention index 61a777247900f..7429988c59155 160000 --- a/csrc/flash-attention +++ b/csrc/flash-attention @@ -1 +1 @@ -Subproject commit 61a777247900f6c2a37376f3ffd7134385fdc95c +Subproject commit 7429988c59155582366f5644db292adb6f6f22b1 From 00dc8ade30ad2d7951f198f7b4c1cfcfeb6fc5bc Mon Sep 17 00:00:00 2001 From: skrider Date: Fri, 8 Mar 2024 03:22:06 +0000 Subject: [PATCH 03/81] add reshape_and_cache_flash --- csrc/cache.h | 7 +++++ csrc/cache_kernels.cu | 73 +++++++++++++++++++++++++++++++++++++++++++ csrc/pybind.cpp | 4 +++ 3 files changed, 84 insertions(+) diff --git a/csrc/cache.h b/csrc/cache.h index 765e231abd26f..7ceb43f19a94c 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,6 +23,13 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& kv_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); + // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b8e3a9..003bda342bb47 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -209,6 +209,40 @@ __global__ void reshape_and_cache_kernel( } } +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ kv_cache, // [num_blocks, 2, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_value_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + kv_cache[tgt_value_idx] = __ldg(&key[src_key_idx]); + kv_cache[tgt_value_idx + block_size * num_heads * head_size] = __ldg(&value[src_value_idx]); + } +} + } // namespace vllm #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ @@ -267,6 +301,45 @@ void reshape_and_cache( } } +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& kv_cache, // [num_blocks, 2, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) +{ + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = kv_cache.size(2); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + kv_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + namespace vllm { template diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5c6439fd6909..4b193f0b6c856 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -90,6 +90,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2, From 4ffe256fd3f15e04a2de5c4aeb596c086d634d04 Mon Sep 17 00:00:00 2001 From: skrider Date: Fri, 8 Mar 2024 03:55:48 +0000 Subject: [PATCH 04/81] refactor reshape_and_cache_flash --- csrc/cache_kernels.cu | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 003bda342bb47..1498d73caa506 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -213,8 +213,10 @@ template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ kv_cache, // [num_blocks, 2, block_size, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, const int num_heads, @@ -234,12 +236,12 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_size * num_heads * head_size + const int64_t tgt_value_idx = block_idx * block_stride + block_offset * num_heads * head_size + head_idx * head_size + head_offset; - kv_cache[tgt_value_idx] = __ldg(&key[src_key_idx]); - kv_cache[tgt_value_idx + block_size * num_heads * head_size] = __ldg(&value[src_value_idx]); + k_cache[tgt_value_idx] = __ldg(&key[src_key_idx]); + v_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); } } @@ -304,7 +306,8 @@ void reshape_and_cache( void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& kv_cache, // [num_blocks, 2, block_size, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { @@ -314,10 +317,12 @@ void reshape_and_cache_flash( int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); - int block_size = kv_cache.size(2); + int block_size = kv_cache.size(1); int key_stride = key.stride(0); int value_stride = value.stride(0); + int block_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); @@ -330,8 +335,10 @@ void reshape_and_cache_flash( vllm::reshape_and_cache_flash_kernel<<>>( key.data_ptr(), value.data_ptr(), - kv_cache.data_ptr(), + k_cache.data_ptr(), + v_cache.data_ptr(), slot_mapping.data_ptr(), + block_stride, key_stride, value_stride, num_heads, From 6a2ddf409d7db58ef3b287a1470d2996adb3395f Mon Sep 17 00:00:00 2001 From: skrider Date: Fri, 8 Mar 2024 05:44:49 +0000 Subject: [PATCH 05/81] refactor reshape_and_cache_flash --- csrc/cache.h | 3 ++- csrc/cache_kernels.cu | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 7ceb43f19a94c..9448978ad9fcd 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -26,7 +26,8 @@ void reshape_and_cache( void reshape_and_cache_flash( torch::Tensor& key, torch::Tensor& value, - torch::Tensor& kv_cache, + torch::Tensor& key_cache, + torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1498d73caa506..dad156cb80cc9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -317,7 +317,7 @@ void reshape_and_cache_flash( int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); - int block_size = kv_cache.size(1); + int block_size = k_cache.size(1); int key_stride = key.stride(0); int value_stride = value.stride(0); From 0e45f5d486e26edf7dcddcf508da6ba1ccdb30f0 Mon Sep 17 00:00:00 2001 From: skrider Date: Wed, 27 Mar 2024 06:26:20 +0000 Subject: [PATCH 06/81] implement flash attention decode backend --- vllm/attention/backends/flash_attn_decode.py | 232 +++++++++++++++++++ vllm/attention/ops/flash_attn.py | 150 ++++++++++++ vllm/attention/selector.py | 14 +- 3 files changed, 392 insertions(+), 4 deletions(-) create mode 100644 vllm/attention/backends/flash_attn_decode.py create mode 100644 vllm/attention/ops/flash_attn.py diff --git a/vllm/attention/backends/flash_attn_decode.py b/vllm/attention/backends/flash_attn_decode.py new file mode 100644 index 0000000000000..707fc3cd6c785 --- /dev/null +++ b/vllm/attention/backends/flash_attn_decode.py @@ -0,0 +1,232 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch +from flash_attn import flash_attn_varlen_func + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.flash_attn import (FlashAttention, + FlashAttentionMetadata) + + +class FlashAttentionDecodeBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionDecodeImpl"]: + return FlashAttentionDecodeImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashAttentionDecodeMetadata": + return FlashAttentionDecodeMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return FlashAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + FlashAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + FlashAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class FlashAttentionDecodeMetadata(AttentionMetadata, FlashAttentionMetadata): + """Metadata for FlashAttentionDecodeBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + +class FlashAttentionDecodeImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = FlashAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionDecodeMetadata, + ) -> torch.Tensor: + """Forward pass with FlashAttentionDecode and FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = FlashAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + FlashAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + # prefix-enabled attention + output = FlashAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = FlashAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/ops/flash_attn.py b/vllm/attention/ops/flash_attn.py new file mode 100644 index 0000000000000..192a3e6e2a2cf --- /dev/null +++ b/vllm/attention/ops/flash_attn.py @@ -0,0 +1,150 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm._C import cache_ops, ops +from vllm.attention.ops.prefix_prefill import context_attention_fwd +from flash_attn import flash_attn_with_kvcache + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class FlashAttentionMetadata: + """Metadata for FlashAttention.""" + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. + context_lens: Optional[torch.Tensor] + # Maximum context length in the batch. + max_context_len: Optional[int] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + kv_cache_dtype: str + + +class FlashAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) + + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + ) -> None: + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + # TODO(skrider) tune num_splits heuristic + return flash_attn_with_kvcache( + query.unsqueeze(1), + key_cache, + value_cache, + None, + None, + cache_seqlens=context_lens, + block_table=block_tables, + softmax_scale=scale, + alibi_slopes=alibi_slopes, + ) + + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + prompt_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_subquery_len: int, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 90fce1a0349b2..23e365e6ab294 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -12,10 +12,16 @@ @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: if _can_use_flash_attn(dtype): - logger.info("Using FlashAttention backend.") - from vllm.attention.backends.flash_attn import ( # noqa: F401 - FlashAttentionBackend) - return FlashAttentionBackend + if __import__("os").environ.get("VLLM_TEMP_USE_FLASH_DECODE", "0") == "1": + from vllm.attention.backends.flash_attn_decode import ( + FlashAttentionDecodeBackend) + logger.info("Using FlashAttentionDecode backend.") + return FlashAttentionDecodeBackend + else: + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + logger.info("Using FlashAttention backend.") + return FlashAttentionBackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 From 45c3662e85d6ac5449ec0e47f34e28cb31c03efb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Mar 2024 19:37:51 +0000 Subject: [PATCH 07/81] FlashAttentionDecode -> FlashAttention --- vllm/attention/backends/flash_attn.py | 107 +++++---- vllm/attention/backends/flash_attn_decode.py | 232 ------------------- vllm/attention/ops/flash_attn.py | 150 ------------ vllm/attention/selector.py | 12 +- 4 files changed, 64 insertions(+), 437 deletions(-) delete mode 100644 vllm/attention/backends/flash_attn_decode.py delete mode 100644 vllm/attention/ops/flash_attn.py diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e50d52377b8e0..69b4678dc7f30 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,19 +1,13 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -33,8 +27,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size != 16: + raise ValueError("FlashAttention only supports block size 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -42,18 +37,26 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: Dict[int, int], ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: Dict[int, List[int]], ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -103,6 +106,26 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. + context_lens: Optional[torch.Tensor] + # Maximum context length in the batch. + max_context_len: Optional[int] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + kv_cache_dtype: str + class FlashAttentionImpl(AttentionImpl): """ @@ -143,10 +166,10 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() + suppored_head_sizes = [32, 64, 96, 128, 160, 192, 224, 256] if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " + f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {suppored_head_sizes}.") def forward( @@ -157,13 +180,13 @@ def forward( kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -175,16 +198,20 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + attn_metadata.kv_cache_dtype, + ) if attn_metadata.is_prompt: # Prompt run. @@ -207,32 +234,20 @@ def forward( ) else: # prefix-enabled attention - output = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, - self.alibi_slopes, - ) + raise NotImplementedError( + "Prefix-enabled attention is not supported by " + "the FlashAttention backend yet.") else: # Decoding run. - output = PagedAttention.forward_decode( - query, + # TODO(skrider): tune num_splits heuristic + output = flash_attn_with_kvcache( + query.unsqueeze(1), key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, + block_table=attn_metadata.block_tables, + cache_seqlens=attn_metadata.context_lens, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/flash_attn_decode.py b/vllm/attention/backends/flash_attn_decode.py deleted file mode 100644 index 707fc3cd6c785..0000000000000 --- a/vllm/attention/backends/flash_attn_decode.py +++ /dev/null @@ -1,232 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type - -import torch -from flash_attn import flash_attn_varlen_func - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) -from vllm.attention.ops.flash_attn import (FlashAttention, - FlashAttentionMetadata) - - -class FlashAttentionDecodeBackend(AttentionBackend): - - @staticmethod - def get_impl_cls() -> Type["FlashAttentionDecodeImpl"]: - return FlashAttentionDecodeImpl - - @staticmethod - def make_metadata(*args, **kwargs) -> "FlashAttentionDecodeMetadata": - return FlashAttentionDecodeMetadata(*args, **kwargs) - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return FlashAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - FlashAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - ) -> None: - FlashAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class FlashAttentionDecodeMetadata(AttentionMetadata, FlashAttentionMetadata): - """Metadata for FlashAttentionDecodeBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| - - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - -class FlashAttentionDecodeImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - suppored_head_sizes = FlashAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionDecodeMetadata, - ) -> torch.Tensor: - """Forward pass with FlashAttentionDecode and FlashAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache is not None: - key_cache, value_cache = FlashAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - FlashAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) - if attn_metadata.is_prompt: - # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - else: - # prefix-enabled attention - output = FlashAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, - self.alibi_slopes, - ) - else: - # Decoding run. - output = FlashAttention.forward_decode( - query, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - - # Reshape the output tensor. - return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/ops/flash_attn.py b/vllm/attention/ops/flash_attn.py deleted file mode 100644 index 192a3e6e2a2cf..0000000000000 --- a/vllm/attention/ops/flash_attn.py +++ /dev/null @@ -1,150 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple - -import torch - -from vllm._C import cache_ops, ops -from vllm.attention.ops.prefix_prefill import context_attention_fwd -from flash_attn import flash_attn_with_kvcache - -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 - - -@dataclass -class FlashAttentionMetadata: - """Metadata for FlashAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - kv_cache_dtype: str - - -class FlashAttention: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (2, num_blocks, block_size * num_kv_heads * head_size) - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) - - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - ) -> None: - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping.flatten(), - kv_cache_dtype, - ) - - @staticmethod - def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - # TODO(skrider) tune num_splits heuristic - return flash_attn_with_kvcache( - query.unsqueeze(1), - key_cache, - value_cache, - None, - None, - cache_seqlens=context_lens, - block_table=block_tables, - softmax_scale=scale, - alibi_slopes=alibi_slopes, - ) - - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_subquery_len: int, - alibi_slopes: Optional[torch.Tensor], - ) -> torch.Tensor: - raise NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 23e365e6ab294..0bc394191cd5a 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -12,16 +12,10 @@ @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: if _can_use_flash_attn(dtype): - if __import__("os").environ.get("VLLM_TEMP_USE_FLASH_DECODE", "0") == "1": - from vllm.attention.backends.flash_attn_decode import ( - FlashAttentionDecodeBackend) - logger.info("Using FlashAttentionDecode backend.") - return FlashAttentionDecodeBackend - else: - from vllm.attention.backends.flash_attn import ( # noqa: F401 + from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - logger.info("Using FlashAttention backend.") - return FlashAttentionBackend + logger.info("Using FlashAttention backend.") + return FlashAttentionBackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 From 18132d2d9de27009749e63a9ba51b4bded4418b3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Mar 2024 19:38:30 +0000 Subject: [PATCH 08/81] Remove gitmodule --- .gitmodules | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .gitmodules diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 8790f31f51adb..0000000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "csrc/flash-attention"] - path = csrc/flash-attention - url = git@github.com:skrider/flash-attention.git From 3cc5ebd805958f6768eecc02bff1f9de5414b8cb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Mar 2024 19:39:28 +0000 Subject: [PATCH 09/81] Minor --- vllm/attention/selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 0bc394191cd5a..4869e6fc5ea95 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -12,9 +12,9 @@ @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: if _can_use_flash_attn(dtype): + logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - logger.info("Using FlashAttention backend.") return FlashAttentionBackend else: logger.info("Using XFormers backend.") From 70f6b16b3fb721692ecda98c56b0949ed5d0ffcc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Mar 2024 19:39:47 +0000 Subject: [PATCH 10/81] Minor --- vllm/attention/selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4869e6fc5ea95..90fce1a0349b2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -14,7 +14,7 @@ def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: if _can_use_flash_attn(dtype): logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 - FlashAttentionBackend) + FlashAttentionBackend) return FlashAttentionBackend else: logger.info("Using XFormers backend.") From 1ebf12e4a885921418d595cfa1524570e1f4576a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Mar 2024 19:40:39 +0000 Subject: [PATCH 11/81] Remove submodule --- csrc/flash-attention | 1 - 1 file changed, 1 deletion(-) delete mode 160000 csrc/flash-attention diff --git a/csrc/flash-attention b/csrc/flash-attention deleted file mode 160000 index 7429988c59155..0000000000000 --- a/csrc/flash-attention +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7429988c59155582366f5644db292adb6f6f22b1 From 8a209ff2b970e3d22d4d3d6d5bbb765e1b9998fa Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 02:28:38 +0000 Subject: [PATCH 12/81] Use prefix-enabled attention --- vllm/attention/backends/flash_attn.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 69b4678dc7f30..74b3cf67d6936 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -234,9 +234,17 @@ def forward( ) else: # prefix-enabled attention - raise NotImplementedError( - "Prefix-enabled attention is not supported by " - "the FlashAttention backend yet.") + assert self.alibi_slopes is None + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ) else: # Decoding run. # TODO(skrider): tune num_splits heuristic From 31f741d0ce1aa222ab7f2f0f8b58b30f32d5fab2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 02:33:43 +0000 Subject: [PATCH 13/81] Disable flash-attn backend --- vllm/attention/selector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 90fce1a0349b2..7e7cc388dae12 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -43,4 +43,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool: except ImportError: logger.info("flash_attn is not found.") return False - return True + # TODO(woosuk): Remove this once our custom build of flash_attn becomes + # available. + logger.info("FlashAttention backend is disabled for now since it requires " + "our custom build of flash_attn.") + return False From 70efebb239fb640a2601f0e78f833d99e8a10de9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 02:42:02 +0000 Subject: [PATCH 14/81] Minor --- csrc/cache_kernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index dad156cb80cc9..d0d61e7375379 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -213,8 +213,8 @@ template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, @@ -306,8 +306,8 @@ void reshape_and_cache( void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { From 56b78e6a32b04f4f2110fcca1776e18530098bce Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 04:22:19 +0000 Subject: [PATCH 15/81] Remove __ldg for AMD portability --- csrc/cache_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d0d61e7375379..eab61861fb728 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -240,8 +240,8 @@ __global__ void reshape_and_cache_flash_kernel( + block_offset * num_heads * head_size + head_idx * head_size + head_offset; - k_cache[tgt_value_idx] = __ldg(&key[src_key_idx]); - v_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + k_cache[tgt_value_idx] = key[src_key_idx]; + v_cache[tgt_value_idx] = value[src_value_idx]; } } From f119396709d9f49e825096727f2596033e8ae1cd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 07:42:11 +0000 Subject: [PATCH 16/81] Remove assert --- vllm/attention/backends/flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 74b3cf67d6936..caf82c5e0b988 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -234,7 +234,6 @@ def forward( ) else: # prefix-enabled attention - assert self.alibi_slopes is None output = flash_attn_varlen_func( q=query, k=key_cache, From b6a1833c656e85dcbd56abae27e481bc85ddd0c4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 15:26:15 +0000 Subject: [PATCH 17/81] Add causal=True --- vllm/attention/backends/flash_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index caf82c5e0b988..2d8b3c293082a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -254,6 +254,7 @@ def forward( block_table=attn_metadata.block_tables, cache_seqlens=attn_metadata.context_lens, softmax_scale=self.scale, + causal=True, alibi_slopes=self.alibi_slopes, ) From da5067881a75a620fb88574a6b3f9015296c1dc9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 15:28:43 +0000 Subject: [PATCH 18/81] Enable when vllm_flash_attn --- vllm/attention/selector.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7e7cc388dae12..54319a31e37f9 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -39,12 +39,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool: return False try: - import flash_attn # noqa: F401 + import vllm_flash_attn # noqa: F401 except ImportError: - logger.info("flash_attn is not found.") + logger.info("vllm_flash_attn is not found.") return False - # TODO(woosuk): Remove this once our custom build of flash_attn becomes - # available. - logger.info("FlashAttention backend is disabled for now since it requires " - "our custom build of flash_attn.") - return False + return True From 37cb5a9c6825f072421539be624e7851acbd2586 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 15:37:08 +0000 Subject: [PATCH 19/81] Add vllm-flash-attn as dependency --- requirements.txt | 3 ++- setup.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a85f5d2c60a6d..ad7ba99d09aa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ numpy torch == 2.1.2 transformers >= 4.39.1 # Required for StarCoder2 & Llava. xformers == 0.0.23.post1 # Required for CUDA 12.1. +vllm-flash-attn == 2.5.6 # Requires PyTorch 2.1.2. fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. @@ -14,4 +15,4 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 outlines == 0.0.34 -tiktoken == 0.6.0 # Required for DBRX tokenizer +tiktoken == 0.6.0 # Required for DBRX tokenizer diff --git a/setup.py b/setup.py index 225fda0a0b412..567f847dd1f8a 100644 --- a/setup.py +++ b/setup.py @@ -316,6 +316,13 @@ def get_requirements() -> List[str]: if _is_cuda(): with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") + if get_nvcc_cuda_version() <= Version("11.8"): + # Remove vllm-flash-attn from requirements for CUDA 11.x build + # as it is optional and not supported. + requirements = [ + r for r in requirements + if not r.startswith("vllm-flash-attn") + ] elif _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") From df138245c53c2f83616c19fa551fb7315e1ae0ad Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 15:58:17 +0000 Subject: [PATCH 20/81] Fix prefix attention --- vllm/attention/backends/flash_attn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2d8b3c293082a..9a74740865f48 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -234,14 +234,15 @@ def forward( ) else: # prefix-enabled attention - output = flash_attn_varlen_func( + output = flash_attn_with_kvcache( q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=attn_metadata.context_lens, # FIXME + block_table=attn_metadata.block_tables, softmax_scale=self.scale, causal=True, + window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) else: From 9a02294b83db73d50eb813052429025fb6d86586 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Mar 2024 16:29:30 +0000 Subject: [PATCH 21/81] Fix prefix attention --- vllm/attention/backends/flash_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 9a74740865f48..44afd2f5360e7 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -234,16 +234,19 @@ def forward( ) else: # prefix-enabled attention - output = flash_attn_with_kvcache( + output = flash_attn_varlen_func( q=query, - k_cache=key_cache, - v_cache=value_cache, - cache_seqlens=attn_metadata.context_lens, # FIXME - block_table=attn_metadata.block_tables, + k=key_cache, + v=value_cache, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.context_lens, # FIXME + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_context_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + block_table=attn_metadata.block_tables, ) else: # Decoding run. From 4553846f3b935be3ec55bc0691c1e9f953612e85 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 28 Mar 2024 22:09:53 +0000 Subject: [PATCH 22/81] add test --- tests/kernels/test_flash_attn.py | 124 +++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..69217497494b2 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,124 @@ +from xformers import ops as xops +import torch +import random +from vllm_flash_attn import flash_attn_varlen_func +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +import pytest + +NUM_HEADS = [8] +NUM_QUERIES_PER_KV = [1] +HEAD_SIZES = [128] +DTYPES = [torch.float16] + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_flashatten_varlen( + num_heads: int, num_queries_per_kv: int, head_size: int, dtype: torch.dtype +): + + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device("cuda") + batch_size = 10 + cache_size = 640 + block_size = 16 + + prefix_lens = [random.randint(16, 128) for _ in range(batch_size)] + append_lens = [random.randint(16, 128) for _ in range(batch_size)] + seq_lens = [a + b for a, b in zip(prefix_lens, append_lens)] + + num_tokens = sum(append_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + + num_kv_heads = num_heads // num_queries_per_kv + key_value = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + key_value.uniform_(-1e-3, 1e-3) + key, value = key_value.unbind(dim=1) + + append_key = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) + append_value = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) + + values = torch.arange(0, cache_size, dtype=torch.int32) + values = values[torch.randperm(cache_size)] + max_block_per_request = int(cache_size / batch_size) + block_table = values[: batch_size * max_block_per_request].view( + batch_size, max_block_per_request + ) + + k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) + v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) + + qo_indptr = torch.cumsum(torch.tensor([0] + append_lens), dim=0, dtype=torch.int32) + seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens), dim=0, dtype=torch.int32 + ) + cu_prefix_lens = torch.cumsum(torch.tensor([0]+prefix_lens), dim=0, dtype=torch.int32) + + for i in range(batch_size): + # copy key, value to append_key, append_value + for j in range(append_lens[i]): + append_key[qo_indptr[i] + j].copy_( + key[seq_start_loc[i] + prefix_lens[i] + j] + ) + append_value[qo_indptr[i] + j].copy_( + value[seq_start_loc[i] + prefix_lens[i] + j] + ) + + # copy key, value to kv cache + cur_prefix_id = 0 + block_id = 0 + while cur_prefix_id < prefix_lens[i]: + start_loc = seq_start_loc[i] + cur_prefix_id + if cur_prefix_id + block_size > prefix_lens[i]: + end_loc = seq_start_loc[i] + prefix_lens[i] + else: + end_loc = start_loc + block_size + + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) + cur_prefix_id += block_size + block_id += 1 + + + scale = float(1.0 / (head_size**0.5)) + output = flash_attn_varlen_func(q=query, + k=k_cache, + v=v_cache, + cu_seqlens_q=qo_indptr, + cu_seqlens_k=cu_prefix_lens, + max_seqlen_q=max(append_lens), + max_seqlen_k=max(seq_lens), + softmax_scale=scale, + block_table=block_table) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( + append_lens, seq_lens + ) + attn_op = xops.fmha.cutlass.FwOp() + output_ref = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ).squeeze(0) + + assert torch.allclose(output_ref, output, atol=1e-4, rtol=1e-5) From ff193047c787227a05683b1bbb1694bef9609f26 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 28 Mar 2024 23:59:51 +0000 Subject: [PATCH 23/81] fix --- tests/kernels/test_flash_attn.py | 147 +++++++++++++++++++++++++++---- 1 file changed, 129 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 69217497494b2..b5cf493130e6c 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,12 +4,129 @@ from vllm_flash_attn import flash_attn_varlen_func from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask import pytest +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat NUM_HEADS = [8] NUM_QUERIES_PER_KV = [1] HEAD_SIZES = [128] DTYPES = [torch.float16] +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @@ -35,15 +152,15 @@ def test_flashatten_varlen( num_tokens = sum(append_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1e-3, 1e-3) + query.uniform_(-1e-1, 1e-1) num_kv_heads = num_heads // num_queries_per_kv key_value = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - key_value.uniform_(-1e-3, 1e-3) + key_value.uniform_(-1e-1, 1e-1) key, value = key_value.unbind(dim=1) - append_key = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) - append_value = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) + # append_key = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) + # append_value = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] @@ -59,25 +176,16 @@ def test_flashatten_varlen( seq_start_loc = torch.cumsum( torch.tensor([0] + seq_lens), dim=0, dtype=torch.int32 ) - cu_prefix_lens = torch.cumsum(torch.tensor([0]+prefix_lens), dim=0, dtype=torch.int32) + cu_prefix_lens = torch.cumsum(torch.tensor([0] + prefix_lens), dim=0, dtype=torch.int32) for i in range(batch_size): - # copy key, value to append_key, append_value - for j in range(append_lens[i]): - append_key[qo_indptr[i] + j].copy_( - key[seq_start_loc[i] + prefix_lens[i] + j] - ) - append_value[qo_indptr[i] + j].copy_( - value[seq_start_loc[i] + prefix_lens[i] + j] - ) - # copy key, value to kv cache cur_prefix_id = 0 block_id = 0 - while cur_prefix_id < prefix_lens[i]: + while cur_prefix_id < seq_lens[i]: start_loc = seq_start_loc[i] + cur_prefix_id - if cur_prefix_id + block_size > prefix_lens[i]: - end_loc = seq_start_loc[i] + prefix_lens[i] + if cur_prefix_id + block_size > seq_lens[i]: + end_loc = seq_start_loc[i] + seq_lens[i] else: end_loc = start_loc + block_size @@ -98,7 +206,7 @@ def test_flashatten_varlen( k=k_cache, v=v_cache, cu_seqlens_q=qo_indptr, - cu_seqlens_k=cu_prefix_lens, + cu_seqlens_k=seq_start_loc, max_seqlen_q=max(append_lens), max_seqlen_k=max(seq_lens), softmax_scale=scale, @@ -120,5 +228,8 @@ def test_flashatten_varlen( scale=scale, op=attn_op, ).squeeze(0) + + output_ref, _ = attention_ref(query,key,value, causal=True) + assert torch.allclose(output_ref, output, atol=1e-4, rtol=1e-5) From dbeeb8a174bb9694d3c9876238955f471bba0aea Mon Sep 17 00:00:00 2001 From: skrider Date: Tue, 2 Apr 2024 04:20:35 +0000 Subject: [PATCH 24/81] fix test --- tests/kernels/test_flash_attn.py | 180 +++++-------------------------- 1 file changed, 26 insertions(+), 154 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index b5cf493130e6c..936c1faddeb52 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,139 +4,18 @@ from vllm_flash_attn import flash_attn_varlen_func from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask import pytest -import math - -import pytest -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - NUM_HEADS = [8] NUM_QUERIES_PER_KV = [1] HEAD_SIZES = [128] DTYPES = [torch.float16] - -def construct_local_mask( - seqlen_q, - seqlen_k, - window_size=(-1, -1), # -1 means infinite window size - query_padding_mask=None, - key_padding_mask=None, - device=None, -): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - sk = ( - seqlen_k - if key_padding_mask is None - else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - ) - sq = ( - seqlen_q - if query_padding_mask is None - else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - ) - if window_size[0] < 0: - return col_idx > row_idx + sk - sq + window_size[1] - else: - sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - col_idx < row_idx + sk - sq - window_size[0], - ) - -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - attn_bias=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - upcast=True, - reorder_ops=False, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. - Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - d = q.shape[-1] - if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) - else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - q.device, - ) - scores.masked_fill_(local_mask, float("-inf")) - if attn_bias is not None: - scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) - # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - if dropout_mask is not None: - attention_drop = attention.masked_fill(~dropout_mask, 0.0) - else: - attention_drop = attention - output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @torch.inference_mode() -def test_flashatten_varlen( +def test_flashinfer_append( num_heads: int, num_queries_per_kv: int, head_size: int, dtype: torch.dtype ): - random.seed(0) torch.manual_seed(0) if torch.cuda.is_available(): @@ -145,39 +24,32 @@ def test_flashatten_varlen( batch_size = 10 cache_size = 640 block_size = 16 - - prefix_lens = [random.randint(16, 128) for _ in range(batch_size)] - append_lens = [random.randint(16, 128) for _ in range(batch_size)] + prefix_lens = [random.randint(100, 200) for _ in range(batch_size)] + append_lens = [random.randint(1, 5) for _ in range(batch_size)] seq_lens = [a + b for a, b in zip(prefix_lens, append_lens)] - num_tokens = sum(append_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-1, 1e-1) - num_kv_heads = num_heads // num_queries_per_kv key_value = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) key_value.uniform_(-1e-1, 1e-1) key, value = key_value.unbind(dim=1) - - # append_key = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) - # append_value = torch.zeros(sum(append_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] max_block_per_request = int(cache_size / batch_size) block_table = values[: batch_size * max_block_per_request].view( batch_size, max_block_per_request ) - k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) - qo_indptr = torch.cumsum(torch.tensor([0] + append_lens), dim=0, dtype=torch.int32) seq_start_loc = torch.cumsum( torch.tensor([0] + seq_lens), dim=0, dtype=torch.int32 ) - cu_prefix_lens = torch.cumsum(torch.tensor([0] + prefix_lens), dim=0, dtype=torch.int32) - + paged_kv_last_page_len = [] + paged_kv_indptr = [0] + page_kv_indices = [] + total_block_num = 0 for i in range(batch_size): # copy key, value to kv cache cur_prefix_id = 0 @@ -188,7 +60,6 @@ def test_flashatten_varlen( end_loc = seq_start_loc[i] + seq_lens[i] else: end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( @@ -199,25 +70,29 @@ def test_flashatten_varlen( ) cur_prefix_id += block_size block_id += 1 - - - scale = float(1.0 / (head_size**0.5)) - output = flash_attn_varlen_func(q=query, - k=k_cache, - v=v_cache, - cu_seqlens_q=qo_indptr, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max(append_lens), - max_seqlen_k=max(seq_lens), - softmax_scale=scale, - block_table=block_table) - + paged_kv_last_page_len.append((seq_lens[i] - 1) % block_size + 1) + cur_block_num = (seq_lens[i] - 1) // block_size + 1 + page_kv_indices.extend(block_table[i, :cur_block_num]) + total_block_num += cur_block_num + paged_kv_indptr.append(total_block_num) + output = flash_attn_varlen_func( + query, + k_cache, + v_cache, + cu_seqlens_q=qo_indptr, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max(append_lens), + max_seqlen_k=max(seq_lens), + causal=True, + block_table=block_table, + ) query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( append_lens, seq_lens ) + scale = float(1.0 / (head_size**0.5)) attn_op = xops.fmha.cutlass.FwOp() output_ref = xops.memory_efficient_attention_forward( query, @@ -228,8 +103,5 @@ def test_flashatten_varlen( scale=scale, op=attn_op, ).squeeze(0) - - output_ref, _ = attention_ref(query,key,value, causal=True) - - - assert torch.allclose(output_ref, output, atol=1e-4, rtol=1e-5) + print((output - output_ref).abs().max()) + assert torch.allclose(output_ref, output, atol=1e-4, rtol=1e-2) From 358886f85e914f266d3ee383e6c29e6d494c5afd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:03:04 +0000 Subject: [PATCH 25/81] Add vllm-flash-attn to requirements-cuda --- requirements-cuda.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c6d2cd46aee54..248b4b06a1025 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 +vllm-flash-attn == 2.5.7 # Requires PyTorch 2.2.1 From 00138596beb46d0b210c2b4786a319b1affdaacc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:03:18 +0000 Subject: [PATCH 26/81] Minor --- vllm/attention/selector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index b5bf1b3023edc..4e5a97c831cca 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -74,7 +74,6 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: try: import vllm_flash_attn # noqa: F401 except ImportError: - logger.info("vllm_flash_attn is not found.") logger.info( "Cannot use FlashAttention backend because the vllm_flash_attn " "package is not found. `pip install vllm-flash-attn` for better " From 45cb2d68bbf13dded6c8f71e96627ddfcfac37bc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:19:29 +0000 Subject: [PATCH 27/81] Fix --- vllm/attention/backends/flash_attn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 28df46c203c62..34aec36750c36 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -57,8 +57,7 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is From 2800f2e73e4cd396c1daa22ae1c21c4b419991a1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:21:14 +0000 Subject: [PATCH 28/81] yapf --- tests/kernels/test_flash_attn.py | 62 +++++++++++++++++++------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 936c1faddeb52..7b19590a217f0 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,18 +4,20 @@ from vllm_flash_attn import flash_attn_varlen_func from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask import pytest + NUM_HEADS = [8] NUM_QUERIES_PER_KV = [1] HEAD_SIZES = [128] DTYPES = [torch.float16] + + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @torch.inference_mode() -def test_flashinfer_append( - num_heads: int, num_queries_per_kv: int, head_size: int, dtype: torch.dtype -): +def test_flashinfer_append(num_heads: int, num_queries_per_kv: int, + head_size: int, dtype: torch.dtype): random.seed(0) torch.manual_seed(0) if torch.cuda.is_available(): @@ -31,21 +33,34 @@ def test_flashinfer_append( query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-1, 1e-1) num_kv_heads = num_heads // num_queries_per_kv - key_value = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + key_value = torch.empty(sum(seq_lens), + 2, + num_kv_heads, + head_size, + dtype=dtype) key_value.uniform_(-1e-1, 1e-1) key, value = key_value.unbind(dim=1) values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] max_block_per_request = int(cache_size / batch_size) - block_table = values[: batch_size * max_block_per_request].view( - batch_size, max_block_per_request - ) - k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) - v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, dtype=dtype) - qo_indptr = torch.cumsum(torch.tensor([0] + append_lens), dim=0, dtype=torch.int32) - seq_start_loc = torch.cumsum( - torch.tensor([0] + seq_lens), dim=0, dtype=torch.int32 - ) + block_table = values[:batch_size * max_block_per_request].view( + batch_size, max_block_per_request) + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + qo_indptr = torch.cumsum(torch.tensor([0] + append_lens), + dim=0, + dtype=torch.int32) + seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens), + dim=0, + dtype=torch.int32) paged_kv_last_page_len = [] paged_kv_indptr = [0] page_kv_indices = [] @@ -62,12 +77,12 @@ def test_flashinfer_append( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc] - ) - v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc] - ) + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) cur_prefix_id += block_size block_id += 1 paged_kv_last_page_len.append((seq_lens[i] - 1) % block_size + 1) @@ -76,9 +91,9 @@ def test_flashinfer_append( total_block_num += cur_block_num paged_kv_indptr.append(total_block_num) output = flash_attn_varlen_func( - query, - k_cache, - v_cache, + query, + k_cache, + v_cache, cu_seqlens_q=qo_indptr, cu_seqlens_k=seq_start_loc, max_seqlen_q=max(append_lens), @@ -90,8 +105,7 @@ def test_flashinfer_append( key = key.unsqueeze(0) value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - append_lens, seq_lens - ) + append_lens, seq_lens) scale = float(1.0 / (head_size**0.5)) attn_op = xops.fmha.cutlass.FwOp() output_ref = xops.memory_efficient_attention_forward( From 0fe3ec5a880c60fe437a5a18bc667fefe7ba696f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:22:46 +0000 Subject: [PATCH 29/81] isort --- tests/kernels/test_flash_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 7b19590a217f0..470621dcd4fb9 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -1,9 +1,10 @@ -from xformers import ops as xops -import torch import random + +import pytest +import torch from vllm_flash_attn import flash_attn_varlen_func +from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -import pytest NUM_HEADS = [8] NUM_QUERIES_PER_KV = [1] From 977afb65190baf536760b384befd41aa93af00b9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 18:56:07 +0000 Subject: [PATCH 30/81] Fix --- vllm/attention/backends/flash_attn.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 34aec36750c36..156c7140b6da9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -103,11 +103,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -121,7 +116,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class FlashAttentionImpl(AttentionImpl): From c55627a036737630b6aa58051c4fa6b66d81fe6c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 19:02:23 +0000 Subject: [PATCH 31/81] Fix --- vllm/attention/backends/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 156c7140b6da9..fefd78cafe341 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -275,7 +275,7 @@ def forward( # Decoding run. # TODO(skrider): tune num_splits heuristic output[num_prefill_tokens:] = flash_attn_with_kvcache( - query.unsqueeze(1), + decode_query.unsqueeze(1), key_cache, value_cache, block_table=decode_meta.block_tables, @@ -283,7 +283,7 @@ def forward( softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, - ) + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) From d7767ab78ff48ae07aedc5fda1c93ebb6cd23a67 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 17:51:31 +0000 Subject: [PATCH 32/81] Fix --- vllm/attention/backends/flash_attn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index fefd78cafe341..7553097b54bee 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -260,10 +260,10 @@ def forward( q=query, k=key_cache, v=value_cache, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.context_lens, # FIXME - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_context_len, + cu_seqlens_q=prefill_meta.subquery_start_loc, + max_seqlen_q=prefill_meta.max_subquery_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -273,7 +273,6 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. - # TODO(skrider): tune num_splits heuristic output[num_prefill_tokens:] = flash_attn_with_kvcache( decode_query.unsqueeze(1), key_cache, From f9d200b311ce2c7b7219d58b5ac147109f215653 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 17:59:23 +0000 Subject: [PATCH 33/81] Add test --- tests/kernels/test_flash_attn.py | 216 ++++++++++++++++--------------- 1 file changed, 109 insertions(+), 107 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 470621dcd4fb9..e40ecdfa3c25b 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -1,122 +1,124 @@ -import random +from typing import List, Tuple import pytest import torch from vllm_flash_attn import flash_attn_varlen_func -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -NUM_HEADS = [8] -NUM_QUERIES_PER_KV = [1] -HEAD_SIZES = [128] -DTYPES = [torch.float16] - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode() -def test_flashinfer_append(num_heads: int, num_queries_per_kv: int, - head_size: int, dtype: torch.dtype): - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", [(16, 16), (32, 8), (64, 8)]) +@pytest.mark.parametrize("head_size", [128, 256]) +@pytest.mark.parametrize("block_size", [16, 32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode +def test_flash_attn( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: torch.set_default_device("cuda") - batch_size = 10 - cache_size = 640 - block_size = 16 - prefix_lens = [random.randint(100, 200) for _ in range(batch_size)] - append_lens = [random.randint(1, 5) for _ in range(batch_size)] - seq_lens = [a + b for a, b in zip(prefix_lens, append_lens)] - num_tokens = sum(append_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1e-1, 1e-1) - num_kv_heads = num_heads // num_queries_per_kv - key_value = torch.empty(sum(seq_lens), - 2, + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, num_kv_heads, head_size, dtype=dtype) - key_value.uniform_(-1e-1, 1e-1) - key, value = key_value.unbind(dim=1) - values = torch.arange(0, cache_size, dtype=torch.int32) - values = values[torch.randperm(cache_size)] - max_block_per_request = int(cache_size / batch_size) - block_table = values[:batch_size * max_block_per_request].view( - batch_size, max_block_per_request) - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - qo_indptr = torch.cumsum(torch.tensor([0] + append_lens), - dim=0, - dtype=torch.int32) - seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens), - dim=0, + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - paged_kv_last_page_len = [] - paged_kv_indptr = [0] - page_kv_indices = [] - total_block_num = 0 - for i in range(batch_size): - # copy key, value to kv cache - cur_prefix_id = 0 - block_id = 0 - while cur_prefix_id < seq_lens[i]: - start_loc = seq_start_loc[i] + cur_prefix_id - if cur_prefix_id + block_size > seq_lens[i]: - end_loc = seq_start_loc[i] + seq_lens[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_prefix_id += block_size - block_id += 1 - paged_kv_last_page_len.append((seq_lens[i] - 1) % block_size + 1) - cur_block_num = (seq_lens[i] - 1) // block_size + 1 - page_kv_indices.extend(block_table[i, :cur_block_num]) - total_block_num += cur_block_num - paged_kv_indptr.append(total_block_num) + output = flash_attn_varlen_func( - query, - k_cache, - v_cache, - cu_seqlens_q=qo_indptr, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max(append_lens), - max_seqlen_k=max(seq_lens), + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, causal=True, - block_table=block_table, + block_table=block_tables, ) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - append_lens, seq_lens) - scale = float(1.0 / (head_size**0.5)) - attn_op = xops.fmha.cutlass.FwOp() - output_ref = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias, - p=0.0, + + ref_output = ref_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, scale=scale, - op=attn_op, - ).squeeze(0) - print((output - output_ref).abs().max()) - assert torch.allclose(output_ref, output, atol=1e-4, rtol=1e-2) + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) + + +def ref_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k) + mask = torch.triu(torch.ones(query_len, kv_len), + diagonal=kv_len - query_len + 1).bool() + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) From 5b5cfae19706320d690fafc831d9785f9b1b9a69 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 18:02:41 +0000 Subject: [PATCH 34/81] Fix test --- tests/kernels/test_flash_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index e40ecdfa3c25b..f28d27b3a533b 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -76,7 +76,8 @@ def test_flash_attn( block_tables=block_tables, scale=scale, ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" def ref_attention( From 738f7fc5a5e9f3d143cf68852593cb4dbe4e2156 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 19:54:29 +0000 Subject: [PATCH 35/81] Upgrade vllm-flash-attn --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 248b4b06a1025..5ca79ee62124d 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,4 @@ pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 -vllm-flash-attn == 2.5.7 # Requires PyTorch 2.2.1 +vllm-flash-attn == 2.5.8 # Requires PyTorch 2.3.0 From b45cfb6562fcb44e9e89770cd2b564869d3e9029 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 22:38:05 +0000 Subject: [PATCH 36/81] Minor --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index a31c12a579d5c..42f884c76c620 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -250,7 +250,6 @@ __global__ void reshape_and_cache_flash_kernel( v_cache[tgt_value_idx] = value[src_value_idx]; } } - } // namespace vllm #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ @@ -319,6 +318,7 @@ void reshape_and_cache_flash( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { + // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } From 60648fd42225619f124adb55b6c06dd211f90ce2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 6 May 2024 22:41:06 +0000 Subject: [PATCH 37/81] Fix --- vllm/attention/backends/flash_attn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6a3dfcc224d38..bb57f0de9f692 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -10,6 +10,8 @@ AttentionMetadata, AttentionMetadataPerStage) +_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] + class FlashAttentionBackend(AttentionBackend): @@ -28,8 +30,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size != 16: - raise ValueError("FlashAttention only supports block size 16.") + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod @@ -166,11 +168,10 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = [32, 64, 96, 128, 160, 192, 224, 256] - if head_size not in suppored_head_sizes: + if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") def forward( self, From eeb8050594709d16970915f81b7e27bbfe8461af Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 00:06:17 +0000 Subject: [PATCH 38/81] Fix --- vllm/attention/backends/abstract.py | 5 +++++ vllm/attention/backends/flash_attn.py | 19 +++++++++---------- vllm/attention/backends/flashinfer.py | 16 +++++++--------- vllm/attention/backends/rocm_flash_attn.py | 4 ++++ vllm/attention/backends/torch_sdpa.py | 4 ++++ vllm/attention/backends/xformers.py | 4 ++++ vllm/worker/model_runner.py | 16 +++++++++------- 7 files changed, 42 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 61c9c81d8a7b8..678634b894b39 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,11 @@ class AttentionBackend(ABC): """Abstract class for attention backends.""" + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + @staticmethod @abstractmethod def get_impl_cls() -> Type["AttentionImpl"]: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bb57f0de9f692..821849eb55765 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,6 +15,10 @@ class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flash-attn" + @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @@ -105,12 +109,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -235,7 +233,8 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -262,9 +261,9 @@ def forward( k=key_cache, v=value_cache, cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_subquery_len, + max_seqlen_q=prefill_meta.max_query_len, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -279,7 +278,7 @@ def forward( key_cache, value_cache, block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.context_lens, + cache_seqlens=decode_meta.seq_lens_tensor, softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8ab4b1f12ee36..0ae5485fba146 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,16 +1,10 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -try: - import flashinfer - from flash_attn import flash_attn_varlen_func - from flashinfer import BatchDecodeWithPagedKVCacheWrapper -except ImportError: - flashinfer = None - flash_attn_varlen_func = None - BatchDecodeWithPagedKVCacheWrapper = None - +import flashinfer import torch +from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -20,6 +14,10 @@ class FlashInferBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flashinfer" + @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c411b3971b8f1..0532acf7a7f07 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,6 +17,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index f75a279086a26..4f4fe50ce51a7 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -15,6 +15,10 @@ class TorchSDPABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "torch-sdpa" + @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 60f6d43f2eaa4..8ba2a323da5b7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -20,6 +20,10 @@ class XFormersBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "xformers" + @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab248596490f6..870a922a619c5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,7 +9,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -273,20 +272,23 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) + if self.attn_backend.get_name() == "flash-attn": + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) else: # The first prefill. - prefix_block_tables.append([]) + block_table = [] else: - prefix_block_tables.append([]) + block_table = [] # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 + prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) @@ -395,7 +397,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": attn_metadata = self.attn_backend.make_metadata( is_prompt=True, use_cuda_graph=False, @@ -556,7 +558,7 @@ def _prepare_decode( device=self.device, ) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html From 5bbd2d3cfab3f70a111fd61861e0f59e68b9a106 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 00:15:36 +0000 Subject: [PATCH 39/81] Remove flash-attn from dockerfile --- Dockerfile | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..ddca95c0e8786 100644 --- a/Dockerfile +++ b/Dockerfile @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### From b72cd1349f168ab791ed60332210dd3430866477 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 04:53:14 +0000 Subject: [PATCH 40/81] Add test for flash_attn_with_kv_cache --- tests/kernels/test_flash_attn.py | 170 +++++++++++++++++++++---------- 1 file changed, 118 insertions(+), 52 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index f28d27b3a533b..632223b3715fa 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -2,16 +2,127 @@ import pytest import torch -from vllm_flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k) + mask = torch.triu(torch.ones(query_len, kv_len), + diagonal=kv_len - query_len + 1).bool() + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", [(16, 16), (32, 8), (64, 8)]) -@pytest.mark.parametrize("head_size", [128, 256]) -@pytest.mark.parametrize("block_size", [16, 32]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) @torch.inference_mode -def test_flash_attn( +def test_varlen_with_paged_kv( seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], head_size: int, @@ -67,7 +178,7 @@ def test_flash_attn( block_table=block_tables, ) - ref_output = ref_attention( + ref_output = ref_paged_attn( query=query, key_cache=key_cache, value_cache=value_cache, @@ -78,48 +189,3 @@ def test_flash_attn( ) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" - - -def ref_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k) - mask = torch.triu(torch.ones(query_len, kv_len), - diagonal=kv_len - query_len + 1).bool() - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) From 4230040b1691bd344e0a27b4fcd1e67cec204ed5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 04:53:32 +0000 Subject: [PATCH 41/81] Bump up vllm-flash-attn --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 249ccdafcb5e2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 -vllm-flash-attn == 2.5.8 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 From 7cd9b7348c00d4d5ed7021b112b93f90393ae6ad Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 06:19:36 +0000 Subject: [PATCH 42/81] Handle FP8 KV cache --- vllm/attention/backends/flash_attn.py | 2 ++ vllm/attention/layer.py | 2 +- vllm/attention/selector.py | 43 +++++++++++++++++++-------- vllm/worker/cache_engine.py | 2 +- vllm/worker/cpu_model_runner.py | 4 +-- vllm/worker/cpu_worker.py | 2 +- vllm/worker/model_runner.py | 3 +- 7 files changed, 40 insertions(+), 18 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 02e6972e02077..c1e8eb390e304 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -192,6 +192,8 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + assert not attn_metadata.kv_cache_dtype.startswith("fp8"), ( + "FlashAttention does not support FP8 KV cache.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ee7be26c0876c..1d542a380b044 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -31,7 +31,7 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) + self.backend = get_attn_backend() impl_cls = self.backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 1e38723516862..7af0d920b1450 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,5 @@ import enum -from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch @@ -11,6 +10,8 @@ logger = init_logger(__name__) +_ATTN_BACKEND: Optional[Type[AttentionBackend]] = None + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -20,38 +21,52 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() -@lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + dtype: Optional[torch.dtype] = None, + kv_cache_dtype: Optional[str] = None, +) -> Type[AttentionBackend]: + global _ATTN_BACKEND + if dtype is None: + assert kv_cache_dtype is None, "KV cache dtype should be None." + assert _ATTN_BACKEND is not None, "Attention backend is not set." + return _ATTN_BACKEND + + assert kv_cache_dtype is not None, "KV cache dtype is not set." + assert _ATTN_BACKEND is None, "Attention backend is already set." + backend = _which_attn_to_use(dtype, kv_cache_dtype) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - return FlashAttentionBackend + _ATTN_BACKEND = FlashAttentionBackend elif backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) - return XFormersBackend + _ATTN_BACKEND = XFormersBackend elif backend == _Backend.ROCM_FLASH: logger.info("Using ROCmFlashAttention backend.") from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 ROCmFlashAttentionBackend) - return ROCmFlashAttentionBackend + _ATTN_BACKEND = ROCmFlashAttentionBackend elif backend == _Backend.TORCH_SDPA: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend - return TorchSDPABackend + _ATTN_BACKEND = TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend. ") + logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend - return FlashInferBackend + _ATTN_BACKEND = FlashInferBackend else: raise ValueError("Invalid attention backend.") + return _ATTN_BACKEND -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def _which_attn_to_use( + dtype: torch.dtype, + kv_cache_dtype: str, +) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): return _Backend.TORCH_SDPA @@ -75,6 +90,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 26a60c652b6f4..f8a54ec382e50 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -43,7 +43,7 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend() # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..17b6d73dcc2ef 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -53,9 +53,9 @@ def __init__( self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.model_config.dtype if model_config is not None else None, + self.kv_cache_dtype) # Lazy initialization. self.model: nn.Module # Set after init_Model diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1ef500ac07b8..e5587795cfde6 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend() # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 870a922a619c5..52873063bf7ab 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -142,7 +142,8 @@ def __init__( self.vision_language_config = vision_language_config self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.model_config.dtype if model_config is not None else None, + self.kv_cache_dtype) # Lazy initialization self.model: torch.nn.Module # Set after load_model From 5370f8679dac6f4c491af385817e57cfb4b92c7b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 06:24:18 +0000 Subject: [PATCH 43/81] Add docstring --- vllm/attention/selector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7af0d920b1450..aa8db50d3c680 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -25,6 +25,12 @@ def get_attn_backend( dtype: Optional[torch.dtype] = None, kv_cache_dtype: Optional[str] = None, ) -> Type[AttentionBackend]: + """Returns the attention backend to use. + + For the first call, the backend is selected based on the dtype and + kv_cache_dtype. For subsequent calls, the dtype and kv_cache_dtype should + be None, and the cached backend is returned. + """ global _ATTN_BACKEND if dtype is None: assert kv_cache_dtype is None, "KV cache dtype should be None." From 2f5b9b76cedcb3e78c0c8d9385d3ba6069746e06 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 07:59:32 +0000 Subject: [PATCH 44/81] Fix --- vllm/worker/model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 52873063bf7ab..7c297bc4dd334 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -141,9 +141,9 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None, - self.kv_cache_dtype) + self.attn_backend = (get_attn_backend(self.model_config.dtype, + self.kv_cache_dtype) + if self.model_config is not None else None) # Lazy initialization self.model: torch.nn.Module # Set after load_model From 4b05153f77cdd3a7f626bbae9193a9c0f6b71aa2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 16:33:51 +0000 Subject: [PATCH 45/81] Fix --- vllm/attention/selector.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index aa8db50d3c680..a441863a590f1 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,5 @@ import enum +from functools import lru_cache from typing import Optional, Type import torch @@ -28,45 +29,56 @@ def get_attn_backend( """Returns the attention backend to use. For the first call, the backend is selected based on the dtype and - kv_cache_dtype. For subsequent calls, the dtype and kv_cache_dtype should - be None, and the cached backend is returned. + kv_cache_dtype. The selected backend is cached for subsequent calls. """ global _ATTN_BACKEND if dtype is None: assert kv_cache_dtype is None, "KV cache dtype should be None." assert _ATTN_BACKEND is not None, "Attention backend is not set." return _ATTN_BACKEND + else: + assert kv_cache_dtype is not None, "KV cache dtype is not set." + attn_backend = select_attn_backend(dtype, kv_cache_dtype) + if _ATTN_BACKEND is None: + _ATTN_BACKEND = attn_backend + else: + assert attn_backend == _ATTN_BACKEND, ( + "Cannot change the attention backend after it is set.") + return _ATTN_BACKEND + - assert kv_cache_dtype is not None, "KV cache dtype is not set." - assert _ATTN_BACKEND is None, "Attention backend is already set." +@lru_cache(maxsize=None) +def select_attn_backend( + dtype: torch.dtype, + kv_cache_dtype: str, +) -> Type[AttentionBackend]: backend = _which_attn_to_use(dtype, kv_cache_dtype) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - _ATTN_BACKEND = FlashAttentionBackend + return FlashAttentionBackend elif backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) - _ATTN_BACKEND = XFormersBackend + return XFormersBackend elif backend == _Backend.ROCM_FLASH: logger.info("Using ROCmFlashAttention backend.") from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 ROCmFlashAttentionBackend) - _ATTN_BACKEND = ROCmFlashAttentionBackend + return ROCmFlashAttentionBackend elif backend == _Backend.TORCH_SDPA: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend - _ATTN_BACKEND = TorchSDPABackend + return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend - _ATTN_BACKEND = FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") - return _ATTN_BACKEND def _which_attn_to_use( From 848a1d79288a2ebae8eb1d3dc1a6bd3cd50e3f5c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 7 May 2024 19:03:26 +0000 Subject: [PATCH 46/81] Fix --- vllm/attention/selector.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index a441863a590f1..dce0efb85d7e3 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -38,16 +38,10 @@ def get_attn_backend( return _ATTN_BACKEND else: assert kv_cache_dtype is not None, "KV cache dtype is not set." - attn_backend = select_attn_backend(dtype, kv_cache_dtype) - if _ATTN_BACKEND is None: - _ATTN_BACKEND = attn_backend - else: - assert attn_backend == _ATTN_BACKEND, ( - "Cannot change the attention backend after it is set.") + _ATTN_BACKEND = select_attn_backend(dtype, kv_cache_dtype) return _ATTN_BACKEND -@lru_cache(maxsize=None) def select_attn_backend( dtype: torch.dtype, kv_cache_dtype: str, @@ -81,6 +75,7 @@ def select_attn_backend( raise ValueError("Invalid attention backend.") +@lru_cache(maxsize=None) def _which_attn_to_use( dtype: torch.dtype, kv_cache_dtype: str, From d6996c199bd5d47c559a0620e88e7027f6171314 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 06:24:22 +0000 Subject: [PATCH 47/81] Set block size from beginning --- vllm/worker/cpu_model_runner.py | 21 ++++++------- vllm/worker/cpu_worker.py | 1 + vllm/worker/model_runner.py | 54 +++++++++++++-------------------- vllm/worker/worker.py | 2 +- 4 files changed, 32 insertions(+), 46 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..6c8b1685dadcf 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -4,8 +4,9 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +27,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], @@ -39,27 +41,22 @@ def __init__( self.scheduler_config = scheduler_config # Currently, CPU worker doesn't support chunked prefill. assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend(self.model_config.dtype) # Lazy initialization. self.model: nn.Module # Set after init_Model - self.block_size: int # Set after initial profiling. def load_model(self) -> None: self.model = get_model( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1ef500ac07b8..5e4ae564cb57e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,6 +151,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 46c6730645c1b..b5e582116297c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,8 +9,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -106,6 +107,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", @@ -115,48 +117,40 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() - # Set after load_model. - self.lora_manager: LRUCacheWorkerLoRAManager = None - + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) - - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) - - # Lazy initialization - self.model: torch.nn.Module # Set after load_model - self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables: torch.Tensor # Set after initial profiling. + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + self.attn_backend = get_attn_backend(self.model_config.dtype) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -211,13 +205,6 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size - - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size @@ -835,6 +822,7 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: + assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for idx in range(self.lora_config.max_loras): lora_id = idx + 1 diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 313bcf25d8870..43f6b2b443b70 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -75,6 +75,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, @@ -184,7 +185,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 6b45dfb730879531428d5eb2f521eb31c8b9d283 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 06:37:50 +0000 Subject: [PATCH 48/81] Remove model runner from test_sampler --- tests/samplers/test_sampler.py | 81 ++++++++++------------------------ 1 file changed, 24 insertions(+), 57 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e4fea165a4d46..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -53,7 +46,6 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, device: str, ): @@ -75,7 +67,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, - device) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -448,13 +428,13 @@ def run_test_case(*, ("Invalid test case, expected_penalization does not match computed" "batch size") - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else None, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -480,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sample_probs = None @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner From 6137ad4b1469f8fbb0d61e2521106e3bfa19ed11 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 06:39:05 +0000 Subject: [PATCH 49/81] [Misc] Remove unnecessary ModelRunner import --- tests/samplers/test_sampler.py | 81 ++++++++++------------------------ 1 file changed, 24 insertions(+), 57 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e4fea165a4d46..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -53,7 +46,6 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, device: str, ): @@ -75,7 +67,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, - device) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -448,13 +428,13 @@ def run_test_case(*, ("Invalid test case, expected_penalization does not match computed" "batch size") - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else None, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -480,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sample_probs = None @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner From 7569137e27e43194d9bb23a8b6fdb4bf6edd83a6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 06:47:24 +0000 Subject: [PATCH 50/81] Fix test_logits_processor --- tests/test_logits_processor.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 179e8d25a341b..4ee980505a3ab 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -9,7 +9,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -30,21 +30,15 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, logits_processor, model_runner + return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) @@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. @@ -87,8 +80,8 @@ def pick_ith(token_ids, logits): seq_group_metadata_list, seq_lens, query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) + device=device, + pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, @@ -99,5 +92,3 @@ def pick_ith(token_ids, logits): fake_logits *= logits_processor.scale assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], 1e-4) - - del model_runner From 6bcf10f8843f50ae73e683bb4784b05802eef33d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 07:34:31 +0000 Subject: [PATCH 51/81] Fix test_model_runner --- tests/worker/test_model_runner.py | 90 +++++++++++-------------------- 1 file changed, 32 insertions(+), 58 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e7975d0ef48b9..3e3d2e3f5c53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,27 +1,38 @@ import pytest import torch -from vllm.config import ModelConfig, SchedulerConfig from vllm.distributed.parallel_state import init_distributed_environment +from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) seq_lens = [] seq_group_metadata_list = [] @@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_decode_cuda_graph(batch_size): - model_config = ModelConfig( + model_runner = _create_model_runner( "facebook/opt-125m", - "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_lens = [] seq_group_metadata_list = [] @@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) @@ -260,29 +248,15 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_hybrid_batches(batch_size, enforce_eager, distributed_init): - - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=enforce_eager, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=True, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=True) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None, - is_driver_worker=True) - model_runner.set_block_size(16) # Add prefill requests. seq_lens = [] From 9092bb40e1c5408592e16b0a859f20f63e775c15 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 21:01:49 +0000 Subject: [PATCH 52/81] Enhance attention backend selector --- tests/worker/test_model_runner.py | 1 - vllm/attention/__init__.py | 12 ++-- vllm/attention/backends/abstract.py | 24 +++++--- vllm/attention/backends/flash_attn.py | 29 ++++------ vllm/attention/backends/flashinfer.py | 33 +++++++---- vllm/attention/backends/rocm_flash_attn.py | 23 +++----- vllm/attention/backends/torch_sdpa.py | 32 +++++------ vllm/attention/backends/xformers.py | 22 ++------ vllm/attention/layer.py | 6 +- vllm/attention/selector.py | 66 +++++++++++++--------- vllm/worker/cache_engine.py | 14 ++++- vllm/worker/model_runner.py | 42 +++++++++----- 12 files changed, 167 insertions(+), 137 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3e3d2e3f5c53d..c2d1c5769619b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) - assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 7636b34a16fed..5902c053d55ff 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,13 +1,17 @@ -from vllm.attention.backends.abstract import (AttentionBackend, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.layer import Attention -from vllm.attention.selector import get_attn_backend +from vllm.attention.selector import (get_attn_backend, get_cached_attn_impl, + set_attn_impl) __all__ = [ + "Attention", "AttentionBackend", + "AttentionImpl", "AttentionMetadata", - "Attention", - "get_attn_backend", "AttentionMetadataPerStage", + "get_attn_backend", + "get_cached_attn_impl", + "set_attn_impl", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 64ccb309a0480..a0f07efa719c7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -4,6 +4,7 @@ TypeVar) import torch +import torch.nn as nn class AttentionBackend(ABC): @@ -94,8 +95,6 @@ class AttentionMetadata(Generic[T]): # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The kv cache's data type. - kv_cache_dtype: str def __post_init__(self): if self.num_prefill_tokens > 0: @@ -105,9 +104,8 @@ def __post_init__(self): assert self.decode_metadata is not None -class AttentionImpl(ABC): +class AttentionImpl(nn.Module): - @abstractmethod def __init__( self, num_heads: int, @@ -116,10 +114,22 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - raise NotImplementedError + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads - @abstractmethod def forward( self, query: torch.Tensor, @@ -127,6 +137,6 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 068d2bc7a4885..4d1b842546d3f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -121,11 +121,11 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| + |<--------------- num_prefill_tokens ----------------->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -152,24 +152,21 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "FlashAttention backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -178,7 +175,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata[FlashAttentionMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -192,8 +189,6 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - assert not attn_metadata.kv_cache_dtype.startswith("fp8"), ( - "FlashAttention does not support FP8 KV cache.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. @@ -214,7 +209,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) num_prefill_tokens = attn_metadata.num_prefill_tokens diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36e162671f944..8c58816ffa19e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -149,20 +149,31 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + ) if sliding_window is not None: raise ValueError("Sliding window is not supported in FlashInfer.") self.sliding_window = (-1, -1) - self.alibi_slopes = alibi_slopes - self.scale = scale - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], - kv_scale: float): + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float = 1.0, + ): + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -183,7 +194,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) if prefill_meta := attn_metadata.prefill_metadata: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8fc1af1aa1e1c..82de2936422bb 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,25 +138,18 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. @@ -229,7 +222,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) @@ -323,7 +316,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c29218dfd0cfc..4246a308c5b06 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -83,26 +83,22 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -111,7 +107,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -124,6 +120,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -136,8 +133,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -195,7 +191,7 @@ def forward( attn_metadata.block_tables, attn_metadata.seq_lens_tensor, attn_metadata.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a9150dea5875..99f5f3943cda8 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,18 +149,10 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: @@ -175,7 +167,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[XFormersMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -188,7 +180,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -203,8 +194,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -262,7 +252,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1d542a380b044..9a12f473def00 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -6,7 +6,7 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) -from vllm.attention.selector import get_attn_backend +from vllm.attention.selector import get_cached_attn_impl class Attention(nn.Module): @@ -31,8 +31,8 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - self.backend = get_attn_backend() - impl_cls = self.backend.get_impl_cls() + impl_cls = get_cached_attn_impl() + assert impl_cls is not None self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 889ddf9aa3fbf..cfa9983258c80 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,17 +1,32 @@ import enum +from contextlib import contextmanager from functools import lru_cache from typing import Optional, Type import torch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) -_ATTN_BACKEND: Optional[Type[AttentionBackend]] = None +_CACHED_ATTN_IMPL: Optional[Type[AttentionImpl]] = None + + +@contextmanager +def set_attn_impl(attn_impl: Optional[Type[AttentionImpl]]): + global _CACHED_ATTN_IMPL + prev = _CACHED_ATTN_IMPL + _CACHED_ATTN_IMPL = attn_impl + yield + _CACHED_ATTN_IMPL = prev + + +def get_cached_attn_impl() -> Optional[Type[AttentionImpl]]: + global _CACHED_ATTN_IMPL + return _CACHED_ATTN_IMPL class _Backend(enum.Enum): @@ -22,31 +37,19 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() +@lru_cache(maxsize=None) def get_attn_backend( - dtype: Optional[torch.dtype] = None, - kv_cache_dtype: Optional[str] = None, -) -> Type[AttentionBackend]: - """Returns the attention backend to use. - - For the first call, the backend is selected based on the dtype and - kv_cache_dtype. The selected backend is cached for subsequent calls. - """ - global _ATTN_BACKEND - if dtype is None: - assert kv_cache_dtype is None, "KV cache dtype should be None." - assert _ATTN_BACKEND is not None, "Attention backend is not set." - return _ATTN_BACKEND - else: - assert kv_cache_dtype is not None, "KV cache dtype is not set." - _ATTN_BACKEND = select_attn_backend(dtype, kv_cache_dtype) - return _ATTN_BACKEND - - -def select_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], dtype: torch.dtype, - kv_cache_dtype: str, + kv_cache_dtype: Optional[str], + block_size: int, ) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype, kv_cache_dtype) + backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -75,10 +78,14 @@ def select_attn_backend( raise ValueError("Invalid attention backend.") -@lru_cache(maxsize=None) def _which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], dtype: torch.dtype, - kv_cache_dtype: str, + kv_cache_dtype: Optional[str], + block_size: int, ) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): @@ -103,10 +110,15 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype.startswith("fp8"): + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") return _Backend.XFORMERS + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 964540fe6f8cc..07d51dca226bd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend() + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,7 +64,7 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 947c3504a63b4..b080ae3791e2a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,5 @@ +import contextlib +import functools import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -7,7 +9,7 @@ import torch.nn as nn from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) + get_attn_backend, set_attn_impl) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -142,26 +144,38 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization - self.model: torch.nn.Module # Set after load_model + self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: - with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - ) + attn_impl = self.attn_backend.get_impl_cls() + attn_impl = functools.partial(attn_impl, + kv_cache_dtype=self.kv_cache_dtype) + with set_attn_impl(attn_impl): + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -757,7 +771,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, @@ -976,7 +989,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: From 7a10755fe633fe8fe24fcfd42ab41c77d839e4fb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 21:40:14 +0000 Subject: [PATCH 53/81] Rever flash-attn --- vllm/attention/backends/flash_attn.py | 134 +++++++++++--------------- 1 file changed, 58 insertions(+), 76 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4d1b842546d3f..0abb0476d94a6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,16 +1,20 @@ -"""Attention layer with FlashAttention.""" +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) - -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -34,9 +38,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -44,26 +47,19 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -109,23 +105,15 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| + |<--------------- num_prefill_tokens ----------------->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -159,14 +147,11 @@ def __init__( self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) - if head_size not in _SUPPORTED_HEAD_SIZES: + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") - if kv_cache_dtype != "auto": - raise NotImplementedError( - "FlashAttention backend does not support FP8 KV cache. " - "Please use xFormers backend instead.") + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, @@ -177,19 +162,17 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with FlashAttention and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -197,20 +180,16 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - ) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -230,8 +209,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -252,34 +230,38 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # FIXME(woosuk): FlashAttention does not support FP8 KV cache. - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], ) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) From 0359113d9cc474b586dd2c4286dc2a45b3b9fac0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 21:41:32 +0000 Subject: [PATCH 54/81] Remove test --- tests/kernels/test_flash_attn.py | 191 ------------------------------- 1 file changed, 191 deletions(-) delete mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py deleted file mode 100644 index 632223b3715fa..0000000000000 --- a/tests/kernels/test_flash_attn.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import List, Tuple - -import pytest -import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache - -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] - - -def ref_paged_attn( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k) - mask = torch.triu(torch.ones(query_len, kv_len), - diagonal=kv_len - query_len + 1).bool() - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_flash_attn_with_paged_kv( - kv_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - ).squeeze(1) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - block_table=block_tables, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" From 21945e32d445c03072f42ec57181ef1b45f5e057 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 21:43:01 +0000 Subject: [PATCH 55/81] Enhance attention selector --- tests/worker/test_model_runner.py | 1 - vllm/attention/__init__.py | 12 +++-- vllm/attention/backends/abstract.py | 24 +++++++--- vllm/attention/backends/flash_attn.py | 20 +++----- vllm/attention/backends/flashinfer.py | 33 ++++++++----- vllm/attention/backends/rocm_flash_attn.py | 23 ++++----- vllm/attention/backends/torch_sdpa.py | 32 ++++++------- vllm/attention/backends/xformers.py | 22 +++------ vllm/attention/layer.py | 6 +-- vllm/attention/selector.py | 56 +++++++++++++++++++--- vllm/worker/cache_engine.py | 14 ++++-- vllm/worker/cpu_worker.py | 2 +- vllm/worker/model_runner.py | 53 ++++++++++++-------- 13 files changed, 180 insertions(+), 118 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3e3d2e3f5c53d..c2d1c5769619b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) - assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 7636b34a16fed..5902c053d55ff 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,13 +1,17 @@ -from vllm.attention.backends.abstract import (AttentionBackend, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.layer import Attention -from vllm.attention.selector import get_attn_backend +from vllm.attention.selector import (get_attn_backend, get_cached_attn_impl, + set_attn_impl) __all__ = [ + "Attention", "AttentionBackend", + "AttentionImpl", "AttentionMetadata", - "Attention", - "get_attn_backend", "AttentionMetadataPerStage", + "get_attn_backend", + "get_cached_attn_impl", + "set_attn_impl", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 64ccb309a0480..a0f07efa719c7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -4,6 +4,7 @@ TypeVar) import torch +import torch.nn as nn class AttentionBackend(ABC): @@ -94,8 +95,6 @@ class AttentionMetadata(Generic[T]): # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The kv cache's data type. - kv_cache_dtype: str def __post_init__(self): if self.num_prefill_tokens > 0: @@ -105,9 +104,8 @@ def __post_init__(self): assert self.decode_metadata is not None -class AttentionImpl(ABC): +class AttentionImpl(nn.Module): - @abstractmethod def __init__( self, num_heads: int, @@ -116,10 +114,22 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - raise NotImplementedError + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads - @abstractmethod def forward( self, query: torch.Tensor, @@ -127,6 +137,6 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4bad226512b69..0abb0476d94a6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -140,19 +140,12 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: @@ -167,7 +160,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata[FlashAttentionMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -196,8 +189,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -264,7 +256,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36e162671f944..8c58816ffa19e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -149,20 +149,31 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + ) if sliding_window is not None: raise ValueError("Sliding window is not supported in FlashInfer.") self.sliding_window = (-1, -1) - self.alibi_slopes = alibi_slopes - self.scale = scale - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], - kv_scale: float): + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float = 1.0, + ): + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -183,7 +194,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) if prefill_meta := attn_metadata.prefill_metadata: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8fc1af1aa1e1c..82de2936422bb 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,25 +138,18 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. @@ -229,7 +222,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) @@ -323,7 +316,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c29218dfd0cfc..4246a308c5b06 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -83,26 +83,22 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -111,7 +107,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -124,6 +120,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -136,8 +133,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -195,7 +191,7 @@ def forward( attn_metadata.block_tables, attn_metadata.seq_lens_tensor, attn_metadata.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a9150dea5875..99f5f3943cda8 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,18 +149,10 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype) suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: @@ -175,7 +167,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[XFormersMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -188,7 +180,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -203,8 +194,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -262,7 +252,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ee7be26c0876c..9a12f473def00 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -6,7 +6,7 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) -from vllm.attention.selector import get_attn_backend +from vllm.attention.selector import get_cached_attn_impl class Attention(nn.Module): @@ -31,8 +31,8 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) - impl_cls = self.backend.get_impl_cls() + impl_cls = get_cached_attn_impl() + assert impl_cls is not None self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f4446bac6b8d2..cfa9983258c80 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,16 +1,33 @@ import enum +from contextlib import contextmanager from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) +_CACHED_ATTN_IMPL: Optional[Type[AttentionImpl]] = None + + +@contextmanager +def set_attn_impl(attn_impl: Optional[Type[AttentionImpl]]): + global _CACHED_ATTN_IMPL + prev = _CACHED_ATTN_IMPL + _CACHED_ATTN_IMPL = attn_impl + yield + _CACHED_ATTN_IMPL = prev + + +def get_cached_attn_impl() -> Optional[Type[AttentionImpl]]: + global _CACHED_ATTN_IMPL + return _CACHED_ATTN_IMPL + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -21,8 +38,18 @@ class _Backend(enum.Enum): @lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> Type[AttentionBackend]: + backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -44,14 +71,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend. ") + logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def _which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): return _Backend.TORCH_SDPA @@ -75,6 +110,15 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1fb63a3e47921..07d51dca226bd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,7 +64,7 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5e4ae564cb57e..8363ce6f9ec20 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend() # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3fc76c6142165..b080ae3791e2a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,5 @@ +import contextlib +import functools import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -7,7 +9,7 @@ import torch.nn as nn from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) + get_attn_backend, set_attn_impl) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -142,26 +144,38 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization - self.model: torch.nn.Module # Set after load_model + self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: - with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - ) + attn_impl = self.attn_backend.get_impl_cls() + attn_impl = functools.partial(attn_impl, + kv_cache_dtype=self.kv_cache_dtype) + with set_attn_impl(attn_impl): + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -258,20 +272,23 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) + if self.attn_backend.get_name() == "flash-attn": + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) else: # The first prefill. - prefix_block_tables.append([]) + block_table = [] else: - prefix_block_tables.append([]) + block_table = [] # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 + prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) @@ -754,7 +771,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, @@ -973,7 +989,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: From 72d515586c9f36f53bbcf7b9180843f8dccbade8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 21:58:46 +0000 Subject: [PATCH 56/81] Fix --- vllm/attention/backends/abstract.py | 19 ++++--------------- vllm/attention/backends/flash_attn.py | 13 +++++++++++-- vllm/attention/backends/flashinfer.py | 20 +++++++++++--------- vllm/attention/backends/rocm_flash_attn.py | 13 +++++++++++-- vllm/attention/backends/torch_sdpa.py | 14 ++++++++++++-- vllm/attention/backends/xformers.py | 14 ++++++++++++-- vllm/attention/selector.py | 9 --------- vllm/worker/model_runner.py | 3 +-- 8 files changed, 62 insertions(+), 43 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a0f07efa719c7..98d70fcab1a18 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -4,7 +4,6 @@ TypeVar) import torch -import torch.nn as nn class AttentionBackend(ABC): @@ -104,8 +103,9 @@ def __post_init__(self): assert self.decode_metadata is not None -class AttentionImpl(nn.Module): +class AttentionImpl(ABC): + @abstractmethod def __init__( self, num_heads: int, @@ -116,20 +116,9 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads + raise NotImplementedError + @abstractmethod def forward( self, query: torch.Tensor, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0abb0476d94a6..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -142,10 +142,19 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8c58816ffa19e..35a7f8d4b466c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -151,18 +151,20 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - ) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes if sliding_window is not None: raise ValueError("Sliding window is not supported in FlashInfer.") self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads def forward( self, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 82de2936422bb..539585b46c7aa 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -140,10 +140,19 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads supported_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in supported_head_sizes: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 4246a308c5b06..2dd72a00c6e30 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -85,8 +85,18 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 99f5f3943cda8..cb2028553461f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -151,8 +151,18 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index cfa9983258c80..dc04f3e219217 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -110,15 +110,6 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b080ae3791e2a..08aad469deda3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import contextlib import functools import time from enum import IntEnum @@ -165,7 +164,7 @@ def load_model(self) -> None: attn_impl = self.attn_backend.get_impl_cls() attn_impl = functools.partial(attn_impl, kv_cache_dtype=self.kv_cache_dtype) - with set_attn_impl(attn_impl): + with set_attn_impl(attn_impl): # noqa: SIM117 with CudaMemoryProfiler() as m: self.model = get_model( model_config=self.model_config, From c49d015c2c0e74652c9340d481f568d61fd9defc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 22:01:33 +0000 Subject: [PATCH 57/81] Revert --- vllm/worker/model_runner.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 08aad469deda3..cfd2590de7a3b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -271,19 +271,17 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - if self.attn_backend.get_name() == "flash-attn": - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums + prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) else: # The first prefill. - block_table = [] + prefix_block_tables.append([]) else: - block_table = [] + prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 From 8a8bb1cd7ea2b0dd4e8882f61294435aa9c2f96a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 22:04:24 +0000 Subject: [PATCH 58/81] Fix CPU --- vllm/worker/cpu_model_runner.py | 10 +++++++++- vllm/worker/cpu_worker.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 6c8b1685dadcf..a2f62a8ed2d88 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -53,7 +53,15 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization. self.model: nn.Module # Set after init_Model diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 8363ce6f9ec20..3ee394f9912e9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,15 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend() + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) From 1ff4fbd74e874be872e6bef6c9e4815492a61ce8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 22:05:35 +0000 Subject: [PATCH 59/81] Fix --- vllm/worker/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cfd2590de7a3b..762afdec34c71 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -285,7 +285,6 @@ def _prepare_prompt( # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 - prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 950bc82525745f0c177d6e6723b2533163b2f088 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 22:29:40 +0000 Subject: [PATCH 60/81] Fix --- vllm/attention/backends/flash_attn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 19a543c03b086..f79fecf632adc 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -194,8 +194,6 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - assert not attn_metadata.kv_cache_dtype.startswith("fp8"), ( - "FlashAttention does not support FP8 KV cache.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. From adf545a004eaa3b0edee5d53b52cb12b67f16ac0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 10 May 2024 22:32:31 +0000 Subject: [PATCH 61/81] Fix --- vllm/attention/backends/flashinfer.py | 2 +- vllm/worker/cpu_model_runner.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 35a7f8d4b466c..92d0fe0487516 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -174,7 +174,7 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[FlashInferMetadata], kv_scale: float = 1.0, - ): + ) -> torch.Tensor: assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a2f62a8ed2d88..a5e51e5a40053 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,7 +3,7 @@ import torch from torch import nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import AttentionMetadata, get_attn_backend, set_attn_impl from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -67,14 +67,16 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=self.vision_language_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + attn_impl = self.attn_backend.get_impl_cls() + with set_attn_impl(attn_impl): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def _prepare_prompt( self, From 250eac40317c24f48ad869c6f96827a4c1c50949 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 00:09:39 +0000 Subject: [PATCH 62/81] Fix CPU --- vllm/worker/cpu_model_runner.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a2f62a8ed2d88..a5e51e5a40053 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,7 +3,7 @@ import torch from torch import nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import AttentionMetadata, get_attn_backend, set_attn_impl from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -67,14 +67,16 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=self.vision_language_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + attn_impl = self.attn_backend.get_impl_cls() + with set_attn_impl(attn_impl): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def _prepare_prompt( self, From 974a4f8d3f2f312ded260617e948f67556823f68 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 00:11:27 +0000 Subject: [PATCH 63/81] Fix CPU --- vllm/worker/cpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a5e51e5a40053..35d09e4ec8da6 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -168,7 +168,6 @@ def _prepare_prompt( decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) @@ -252,7 +251,6 @@ def _prepare_decode( prefill_metadata=None, decode_metadata=None, block_tables=block_tables, - kv_cache_dtype=self.kv_cache_dtype, ) return ( input_tokens, From 8a629e52fe3b3f013b87a1ce3a54c5b773e59b0a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 01:13:54 +0000 Subject: [PATCH 64/81] Fix --- vllm/attention/layer.py | 18 +++++- vllm/attention/selector.py | 2 + vllm/model_executor/model_loader/__init__.py | 19 +++--- vllm/model_executor/model_loader/loader.py | 61 +++++++++++++------- vllm/model_executor/models/opt.py | 16 +++-- vllm/worker/cpu_model_runner.py | 21 ++++--- vllm/worker/model_runner.py | 28 ++++----- 7 files changed, 101 insertions(+), 64 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9a12f473def00..b9f5075cf1a2a 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -6,7 +6,8 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) -from vllm.attention.selector import get_cached_attn_impl +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig class Attention(nn.Module): @@ -29,10 +30,21 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() - impl_cls = get_cached_attn_impl() - assert impl_cls is not None + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend( + num_heads, head_size, + num_kv_heads if num_kv_heads is not None else num_heads, + sliding_window, dtype, kv_cache_dtype, block_size) + impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index dc04f3e219217..4f55bccc16748 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -50,6 +50,8 @@ def get_attn_backend( backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, block_size) + logger.info(num_heads, head_size, num_kv_heads, sliding_window, dtype, + kv_cache_dtype, block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 6f90e49994fb2..e3e32d61ab04d 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,26 +2,29 @@ from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model( - *, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: loader = get_model_loader(load_config) return loader.load_model(model_config=model_config, device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, parallel_config=parallel_config, - scheduler_config=scheduler_config) + scheduler_config=scheduler_config, + cache_config=cache_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bafa2de62e5df..fc9c8aa0af44b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,9 @@ import torch from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -77,15 +77,16 @@ def _get_model_initialization_kwargs( return extra_kwargs -def _initialize_model( - model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, + cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -103,7 +104,8 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" ... @@ -216,11 +218,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -253,11 +257,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -286,9 +292,12 @@ def _get_weights_iterator( return tensorizer_weights_iterator(tensorizer_args) def _load_model_unserialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load an unserialized model with tensorizer. @@ -299,15 +308,19 @@ def _load_model_unserialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load a serialized model with tensorizer. @@ -321,6 +334,7 @@ def _load_model_serialized( extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -335,16 +349,19 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) if is_vllm_serialized_tensorizer(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) return self._load_model_unserialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 336f765ababaa..d241756e50f4a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -61,6 +62,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -88,7 +90,8 @@ def __init__( ) self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) def forward( self, @@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -117,6 +121,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, + cache_config=cache_config, quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -181,6 +186,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -226,7 +232,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, quant_config) + OPTDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +265,11 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, quant_config) + self.decoder = OPTDecoder(config, cache_config, quant_config) def forward( self, @@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, quant_config) + self.model = OPTModel(config, cache_config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 35d09e4ec8da6..0a0b0d70cfe21 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,7 +3,7 @@ import torch from torch import nn -from vllm.attention import AttentionMetadata, get_attn_backend, set_attn_impl +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -67,16 +67,15 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - attn_impl = self.attn_backend.get_impl_cls() - with set_attn_impl(attn_impl): - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=self.vision_language_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) def _prepare_prompt( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 762afdec34c71..608c72fd5a6f4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import functools import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -8,7 +7,7 @@ import torch.nn as nn from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend, set_attn_impl) + get_attn_backend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -161,20 +160,17 @@ def __init__( self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: - attn_impl = self.attn_backend.get_impl_cls() - attn_impl = functools.partial(attn_impl, - kv_cache_dtype=self.kv_cache_dtype) - with set_attn_impl(attn_impl): # noqa: SIM117 - with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - ) + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", From d622b3e38f3f65f3485f3e1f8fe9b65b1e031596 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 01:14:33 +0000 Subject: [PATCH 65/81] Fix --- vllm/attention/layer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b9f5075cf1a2a..4f65f8a7d4a0b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -39,11 +39,12 @@ def __init__( else: kv_cache_dtype = "auto" block_size = 16 + if num_kv_heads is None: + num_kv_heads = num_heads dtype = torch.get_default_dtype() - attn_backend = get_attn_backend( - num_heads, head_size, - num_kv_heads if num_kv_heads is not None else num_heads, - sliding_window, dtype, kv_cache_dtype, block_size) + attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) From d27c139b5c9f4d3cb6826f8ef080026a66415dce Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 01:18:23 +0000 Subject: [PATCH 66/81] Fix Llama --- vllm/model_executor/models/llama.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fce..ffc2dbae6c9d4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -93,6 +93,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -152,7 +153,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -175,6 +177,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -203,6 +206,7 @@ def __init__( quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -249,6 +253,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -265,7 +270,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -330,12 +335,13 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config, quant_config, lora_config=lora_config) + self.model = LlamaModel(config, cache_config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size From e4fa4941f059b7162c39802c48660a0abd77d442 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 01:19:15 +0000 Subject: [PATCH 67/81] Fix --- vllm/attention/__init__.py | 5 +---- vllm/attention/selector.py | 18 ------------------ vllm/model_executor/models/llama.py | 5 ++++- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 5902c053d55ff..33a25adbf7330 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -2,8 +2,7 @@ AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.layer import Attention -from vllm.attention.selector import (get_attn_backend, get_cached_attn_impl, - set_attn_impl) +from vllm.attention.selector import get_attn_backend __all__ = [ "Attention", @@ -12,6 +11,4 @@ "AttentionMetadata", "AttentionMetadataPerStage", "get_attn_backend", - "get_cached_attn_impl", - "set_attn_impl", ] diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4f55bccc16748..3ce5f186d7e30 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -12,22 +12,6 @@ logger = init_logger(__name__) -_CACHED_ATTN_IMPL: Optional[Type[AttentionImpl]] = None - - -@contextmanager -def set_attn_impl(attn_impl: Optional[Type[AttentionImpl]]): - global _CACHED_ATTN_IMPL - prev = _CACHED_ATTN_IMPL - _CACHED_ATTN_IMPL = attn_impl - yield - _CACHED_ATTN_IMPL = prev - - -def get_cached_attn_impl() -> Optional[Type[AttentionImpl]]: - global _CACHED_ATTN_IMPL - return _CACHED_ATTN_IMPL - class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -50,8 +34,6 @@ def get_attn_backend( backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, block_size) - logger.info(num_heads, head_size, num_kv_heads, sliding_window, dtype, - kv_cache_dtype, block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ffc2dbae6c9d4..df5102043d4eb 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -341,7 +341,10 @@ def __init__( ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config, cache_config, quant_config, lora_config=lora_config) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size From e2a4ba0f2591d612a3099cd57d3c6b87c5230848 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 01:20:25 +0000 Subject: [PATCH 68/81] yapf --- vllm/attention/__init__.py | 3 +-- vllm/attention/selector.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 33a25adbf7330..088f48def7668 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,4 +1,4 @@ -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.layer import Attention @@ -7,7 +7,6 @@ __all__ = [ "Attention", "AttentionBackend", - "AttentionImpl", "AttentionMetadata", "AttentionMetadataPerStage", "get_attn_backend", diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3ce5f186d7e30..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,12 +1,11 @@ import enum -from contextlib import contextmanager from functools import lru_cache from typing import Optional, Type import torch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl +from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip From ee714454b54bcfdb227a612aecd653fc548ca544 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 02:05:01 +0000 Subject: [PATCH 69/81] Update models --- vllm/model_executor/models/arctic.py | 16 +++++++++-- vllm/model_executor/models/baichuan.py | 18 +++++++++--- vllm/model_executor/models/bloom.py | 15 +++++++--- vllm/model_executor/models/chatglm.py | 17 +++++++---- vllm/model_executor/models/commandr.py | 14 ++++++++-- vllm/model_executor/models/dbrx.py | 17 ++++++++--- vllm/model_executor/models/decilm.py | 4 ++- vllm/model_executor/models/deepseek.py | 16 +++++++++-- vllm/model_executor/models/falcon.py | 15 +++++++--- vllm/model_executor/models/gemma.py | 14 +++++++--- vllm/model_executor/models/gpt2.py | 16 ++++++++--- vllm/model_executor/models/gpt_bigcode.py | 14 +++++++--- vllm/model_executor/models/gpt_j.py | 20 +++++++++---- vllm/model_executor/models/gpt_neox.py | 16 ++++++++--- vllm/model_executor/models/internlm2.py | 13 +++++++-- vllm/model_executor/models/jais.py | 12 ++++++-- vllm/model_executor/models/llava.py | 6 ++-- vllm/model_executor/models/minicpm.py | 13 +++++++-- vllm/model_executor/models/mixtral.py | 13 +++++++-- vllm/model_executor/models/mixtral_quant.py | 31 ++++++++++++++------- vllm/model_executor/models/mpt.py | 18 ++++++++---- vllm/model_executor/models/olmo.py | 14 +++++++--- vllm/model_executor/models/orion.py | 13 +++++++-- vllm/model_executor/models/phi.py | 16 ++++++++--- vllm/model_executor/models/qwen.py | 15 ++++++++-- vllm/model_executor/models/qwen2.py | 14 +++++++--- vllm/model_executor/models/qwen2_moe.py | 16 +++++++++-- vllm/model_executor/models/stablelm.py | 14 +++++++--- vllm/model_executor/models/starcoder2.py | 13 +++++++-- vllm/model_executor/models/xverse.py | 14 +++++++--- 30 files changed, 336 insertions(+), 111 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 796cef7c4a735..cb99939cbb17a 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,6 +5,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -215,6 +216,7 @@ def __init__( self, config: ArcticConfig, layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -265,7 +267,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -288,6 +291,7 @@ def __init__( self, config: ArcticConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -297,6 +301,7 @@ def __init__( self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, layer_idx, + cache_config, quant_config=quant_config) self.block_sparse_moe = ArcticMoE( config, @@ -356,6 +361,7 @@ class ArcticModel(nn.Module): def __init__( self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -366,7 +372,10 @@ def __init__( config.hidden_size, org_num_embeddings=self.vocab_size) self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + ArcticDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self._attn_implementation = config._attn_implementation @@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module): def __init__(self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, **kwargs) -> None: super().__init__() self.config = config - self.model = ArcticModel(config, quant_config) + self.model = ArcticModel(config, cache_config, quant_config) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 186cee2584369..f69ec55b431b3 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -111,6 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -162,7 +163,10 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size @@ -197,6 +202,7 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = BaiChuanMLP( @@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -255,7 +262,8 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, quant_config) + BaiChuanDecoderLayer(config, position_embedding, cache_config, + quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,13 +312,15 @@ def __init__( self, config, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = BaiChuanModel(config, position_embedding, quant_config) + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1d7e5d2517c72..fe2de87b20dc9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -71,6 +72,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,7 +110,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + cache_config=cache_config) def forward( self, @@ -158,6 +161,7 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -165,7 +169,8 @@ def __init__( self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, quant_config) + self.self_attention = BloomAttention(config, cache_config, + quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -214,6 +219,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -229,7 +235,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, quant_config) + BloomBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = BloomModel(config, quant_config) + self.transformer = BloomModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e116af2ed080d..cde34feb48868 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -34,6 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -90,6 +91,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -167,6 +169,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -181,7 +184,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, quant_config) + self.self_attention = GLMAttention(config, cache_config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -237,6 +240,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -246,8 +250,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList( - [GLMBlock(config, quant_config) for i in range(self.num_layers)]) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -292,7 +299,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, quant_config) + self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 17c2f1223d96b..7354d11f98b15 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -124,6 +125,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -180,6 +182,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, @@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, quant_config=quant_config) + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -258,6 +264,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +273,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, quant_config=quant_config) + CohereDecoderLayer(config, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -306,7 +314,7 @@ def __init__( self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, quant_config) + self.model = CohereModel(config, cache_config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index a4a0ae50c645e..083ddf0159f71 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -166,6 +167,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -221,6 +223,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -279,10 +282,12 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) self.ffn = DbrxExperts(config, quant_config) def forward( @@ -308,6 +313,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -315,8 +321,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + DbrxBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, quant_config) + self.transformer = DbrxModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index be9a6b6813f8f..e293ee491908d 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -28,7 +28,7 @@ import torch from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, + cache_config=cache_config, quant_config=quant_config, lora_config=lora_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index e5f7ba086a35d..62e04f9649915 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -178,6 +179,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -229,7 +231,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -252,6 +255,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -267,6 +271,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.n_routed_experts is not None @@ -321,6 +326,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -332,7 +338,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) + DeepseekDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, quant_config) + self.model = DeepseekModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 08dd69923dc6d..ab9e1994be426 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -77,6 +78,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -168,7 +170,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, quant_config) + self.self_attention = FalconAttention(config, cache_config, + quant_config) self.mlp = FalconMLP(config, quant_config) self.config = config @@ -311,6 +316,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -327,7 +333,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, quant_config) + FalconDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = FalconModel(config, quant_config) + self.transformer = FalconModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bb73ff4d206da..d1502b718a773 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,7 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -107,6 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -155,7 +156,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -188,6 +191,7 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = GemmaMLP( @@ -236,6 +240,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -246,7 +251,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, quant_config) + GemmaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -316,7 +322,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = GemmaModel(config, quant_config) + self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 75eaebf0dbd15..0deaa58ed9eb5 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -70,7 +72,10 @@ def __init__( bias=True, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config) def forward( self, @@ -122,6 +127,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -130,7 +136,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, quant_config) + self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config) @@ -163,6 +169,7 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -174,7 +181,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, quant_config) + GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, quant_config) + self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index d057fd928fdb5..c20fb3230c394 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -85,7 +87,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -151,7 +155,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, quant_config) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -195,7 +200,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, quant_config) + GPTBigCodeBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 8d7fe8a5beef7..5f4d8ec3d3a7a 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -83,7 +85,10 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -135,13 +140,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, quant_config) + self.attn = GPTJAttention(config, cache_config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -169,6 +175,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,8 +185,10 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList( - [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([ + GPTJBlock(config, cache_config, quant_config) + for _ in range(config.n_layer) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, quant_config) + self.transformer = GPTJModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index bab563b9c5a39..dcb52ff666c95 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -84,7 +86,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -142,7 +148,7 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, quant_config) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -192,7 +199,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, quant_config) + GPTNeoXLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(config, quant_config) + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 5811cae83bf8b..65f7ddb8b082c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -64,6 +65,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -114,7 +116,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -151,6 +155,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.feed_forward = InternLM2MLP( @@ -196,6 +201,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -207,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, quant_config) + InternLMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, quant_config) + self.model = InternLM2Model(config, cache_config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bd6a180ec8dfc..df30fd1ba0a37 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -69,6 +70,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,6 +110,7 @@ def __init__( self.head_dim, scale=self.scale, alibi_slopes=alibi_slopes, + cache_config=cache_config, ) def forward( @@ -170,6 +173,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,7 +182,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, quant_config) + self.attn = JAISAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -211,6 +215,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -228,7 +233,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, quant_config) + JAISBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = JAISModel(config, quant_config) + self.transformer = JAISModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dcde4dfa0795e..3b99b337a2765 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -7,7 +7,7 @@ from transformers import CLIPVisionModel, LlavaConfig from vllm.attention import AttentionMetadata -from vllm.config import VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -85,7 +86,8 @@ def __init__(self, projector_hidden_act=config.projector_hidden_act) self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, quant_config) + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c90bcfbfc4707..0b85cf1c94795 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -28,7 +28,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -181,6 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +235,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -275,6 +278,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) @@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -346,7 +351,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, quant_config) + MiniCPMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -421,6 +427,7 @@ def __init__( self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, + cache_config, quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index efa4de7516212..113abbaa6036d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -252,6 +252,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -313,6 +314,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -348,6 +351,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -394,6 +398,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -410,7 +415,9 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.model = MixtralModel(config, + cache_config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 38c62afced28a..ee2626b1c1aa2 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,7 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -157,14 +158,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -215,6 +219,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -250,6 +256,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) @@ -292,6 +299,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -303,7 +311,9 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config) + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 6fa5c5bd3014a..716ac51cde94d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -43,6 +44,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -107,7 +109,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -166,12 +169,13 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, quant_config) + self.attn = MPTAttention(config, cache_config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -201,6 +205,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -211,8 +216,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + MPTBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -253,7 +261,7 @@ def __init__( assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(config, quant_config) + self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index f212ea2166e1d..69f23bbfb5d0a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,6 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -93,7 +95,8 @@ def __init__( self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) # Attention output projection. self.o_proj = RowParallelLinear( @@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, quant_config) + self.self_attn = OlmoAttention(config, cache_config, quant_config) # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -217,6 +221,7 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -224,7 +229,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, quant_config) + OlmoDecoderLayer(config, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = OlmoModel(config, quant_config) + self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 9ab5dfb97c19a..59cd42e31b374 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -68,6 +69,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -118,7 +120,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -155,6 +159,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = OrionMLP( @@ -202,6 +207,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -213,7 +219,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, quant_config) + OrionDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = OrionModel(config, quant_config) + self.model = OrionModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 4a45879201af3..ed25a232f4208 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -63,6 +64,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads @@ -105,7 +107,10 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -155,11 +160,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, quant_config) + self.self_attn = PhiAttention(config, cache_config, quant_config) self.mlp = PhiMLP(config, quant_config) def forward( @@ -186,6 +192,7 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -193,7 +200,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, quant_config) + PhiLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.quant_config = quant_config - self.model = PhiModel(config, quant_config) + self.model = PhiModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e5e0028888c88..d158846a3a1f5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -68,6 +69,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -101,7 +103,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -123,6 +128,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -135,6 +141,7 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, + cache_config=cache_config, quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -175,6 +182,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -186,7 +194,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, quant_config) + QWenBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = QWenModel(config, quant_config) + self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 62bc7fe22c367..31ba6441f9f7a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -29,7 +29,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -87,6 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -137,7 +138,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + sliding_window=self.sliding_window, + cache_config=cache_config) def forward( self, @@ -160,6 +162,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, + cache_config=cache_config, quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( @@ -222,6 +226,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +239,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, quant_config) + Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -294,7 +300,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8da89a2b7ba6c..2a3b0173adf8b 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -187,6 +188,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -238,7 +240,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -261,6 +264,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -276,6 +280,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.num_experts is not None @@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -339,7 +345,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) + Qwen2MoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(config, quant_config) + self.model = Qwen2MoeModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3d4f4f700f867..922e971c093e4 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -72,6 +73,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config @@ -124,7 +126,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_key_value_heads) + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config) def forward( self, @@ -146,11 +149,12 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.self_attn = StablelmAttention(config) - self.mlp = StablelmMLP(config, quant_config) + self.mlp = StablelmMLP(config, cache_config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( @@ -195,7 +200,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, quant_config) + StablelmDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(config, quant_config) + self.model = StableLMEpochModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 33998e2aad5c5..df5d51983fe04 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -101,6 +103,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -201,7 +208,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, quant_config=quant_config) + Starcoder2DecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 0fb2662b2f715..6ef230a8ebbca 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -27,7 +27,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -89,6 +89,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -133,7 +134,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = XverseMLP( hidden_size=self.hidden_size, @@ -221,6 +225,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -237,7 +242,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, quant_config) + XverseDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config) + self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From 974ed4defc2703b7f1c13fbea1851fbfca106582 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 20:26:00 +0000 Subject: [PATCH 70/81] Fix --- vllm/model_executor/models/chatglm.py | 3 ++- vllm/model_executor/models/starcoder2.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index cde34feb48868..29c76682109c6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -341,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config - self.transformer = ChatGLMModel(config, quant_config) + self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index df5d51983fe04..3c19d63276a77 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -235,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, quant_config=quant_config) + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: From ec720638c7aba5b3e3d22d69453fb245d3143785 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 20:27:28 +0000 Subject: [PATCH 71/81] Fix --- vllm/model_executor/models/baichuan.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index f69ec55b431b3..58b3405d319d1 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -399,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", quant_config, lora_config) + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -414,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) From 180acaa23eb8c648ed51e5d384feca012c6c73bb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 11 May 2024 20:29:16 +0000 Subject: [PATCH 72/81] Fix --- vllm/model_executor/models/stablelm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 922e971c093e4..8b4a5507feade 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -153,8 +153,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.self_attn = StablelmAttention(config) - self.mlp = StablelmMLP(config, cache_config, quant_config) + self.self_attn = StablelmAttention(config, cache_config, quant_config) + self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) From 1c2ad0a7143b562e3cd34e2f366cee3ac277cb5e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 15:10:52 +0000 Subject: [PATCH 73/81] Add comment --- vllm/attention/layer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 4f65f8a7d4a0b..8a872dba8c877 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -41,6 +41,8 @@ def __init__( block_size = 16 if num_kv_heads is None: num_kv_heads = num_heads + # During model initialization, the default dtype is set as the model + # weight and activation dtype. dtype = torch.get_default_dtype() attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, From 8cfb402dc9cd0bf59089b48dd5ba58742b44ffc2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 16:05:07 +0000 Subject: [PATCH 74/81] Remove kv_cache_dtype --- vllm/worker/embedding_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 2d3f160c60dc1..d04bebbdc31b6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -235,7 +235,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, pooling_metadata, From 9a1b8f30b5c9dbcbcef36ea24629405d28ae4486 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 17:38:59 +0000 Subject: [PATCH 75/81] yapf --- vllm/worker/cpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index bf25c2b4259f2..0a0b0d70cfe21 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,7 +3,7 @@ import torch from torch import nn -from vllm.attention import AttentionMetadata, get_attn_backend, set_attn_impl +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) From 304c9e317d68a37ad812bd48f3e6e7dba7820d1d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 17:49:41 +0000 Subject: [PATCH 76/81] Sliding window --- tests/kernels/test_flash_attn.py | 17 ++++++++++++++++- vllm/attention/backends/flash_attn.py | 8 ++++++-- vllm/attention/selector.py | 5 +++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 632223b3715fa..be2c605ac815a 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import pytest import torch @@ -18,6 +18,7 @@ def ref_paged_attn( kv_lens: List[int], block_tables: torch.Tensor, scale: float, + sliding_window: Optional[int] = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -45,6 +46,13 @@ def ref_paged_attn( attn = torch.einsum("qhd,khd->hqk", q, k) mask = torch.triu(torch.ones(query_len, kv_len), diagonal=kv_len - query_len + 1).bool() + # print(mask) + if sliding_window is not None: + sliding_window_mask = torch.triu(torch.ones(query_len, kv_len), + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask attn.masked_fill_(mask, float("-inf")) attn = torch.softmax(attn, dim=-1) out = torch.einsum("hqk,khd->qhd", attn, v) @@ -120,12 +128,14 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("dtype", DTYPES) @torch.inference_mode def test_varlen_with_paged_kv( seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], head_size: int, + sliding_window: Optional[int], dtype: torch.dtype, block_size: int, ) -> None: @@ -140,6 +150,9 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) scale = head_size**-0.5 query = torch.randn(sum(query_lens), @@ -175,6 +188,7 @@ def test_varlen_with_paged_kv( max_seqlen_k=max_kv_len, softmax_scale=scale, causal=True, + window_size=window_size, block_table=block_tables, ) @@ -186,6 +200,7 @@ def test_varlen_with_paged_kv( kv_lens=kv_lens, block_tables=block_tables, scale=scale, + sliding_window=sliding_window, ) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f79fecf632adc..11ecb2792ea9d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -168,6 +168,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " @@ -193,6 +198,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." num_tokens, hidden_size = query.shape @@ -257,7 +263,6 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # FIXME(woosuk): FlashAttention does not support FP8 KV cache. output[:num_prefill_tokens] = flash_attn_varlen_func( q=query, k=key_cache, @@ -268,7 +273,6 @@ def forward( max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, - window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, block_table=prefill_meta.block_tables, ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d94cc4641d5b5..5140c3cc86a31 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -102,6 +102,11 @@ def _which_attn_to_use( "divisible by 16.") return _Backend.XFORMERS + if sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: From 1be2eb3658661e1c152c0bb3cf3116a9df27410a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 17:54:24 +0000 Subject: [PATCH 77/81] yapf --- vllm/attention/backends/flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d6cddaef8167b..11ecb2792ea9d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -290,6 +290,5 @@ def forward( alibi_slopes=self.alibi_slopes, ).squeeze(1) - # Reshape the output tensor. return output.view(num_tokens, hidden_size) From d54461133efd691abac2dc9b7391cf5c7480effd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 19:25:38 +0000 Subject: [PATCH 78/81] Use fp32 in ref attn softmax --- tests/kernels/test_flash_attn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index be2c605ac815a..41c26dd5eae6d 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -43,10 +43,9 @@ def ref_paged_attn( if q.shape[1] != k.shape[1]: k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k) + attn = torch.einsum("qhd,khd->hqk", q, k).float() mask = torch.triu(torch.ones(query_len, kv_len), diagonal=kv_len - query_len + 1).bool() - # print(mask) if sliding_window is not None: sliding_window_mask = torch.triu(torch.ones(query_len, kv_len), diagonal=kv_len - @@ -54,7 +53,7 @@ def ref_paged_attn( 1).bool().logical_not() mask |= sliding_window_mask attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1) + attn = torch.softmax(attn, dim=-1).to(v.dtype) out = torch.einsum("hqk,khd->qhd", attn, v) outputs.append(out) From ddd9e35d4f8e8ca5e4680e8afec65dc58bd1004c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 19:53:30 +0000 Subject: [PATCH 79/81] Fix broken tests --- tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c02204f16ac68..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index e87a1783a83f1..b7781c1cb36f5 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,7 +25,7 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', @@ -36,7 +36,7 @@ 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' From 7e0da785a78c699c885073b0e9c3378cc0c89b46 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 19:57:21 +0000 Subject: [PATCH 80/81] Address comments --- tests/kernels/test_flash_attn.py | 6 +++--- vllm/worker/model_runner.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 41c26dd5eae6d..da3c0333c51e1 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -44,10 +44,10 @@ def ref_paged_attn( k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) attn = torch.einsum("qhd,khd->hqk", q, k).float() - mask = torch.triu(torch.ones(query_len, kv_len), - diagonal=kv_len - query_len + 1).bool() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(torch.ones(query_len, kv_len), + sliding_window_mask = torch.triu(empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1).bool().logical_not() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cdec863195748..3f7e87c1de48c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -267,6 +267,10 @@ def _prepare_prompt( context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. block_table = seq_group_metadata.block_tables[seq_id] else: block_table = computed_block_nums From cd22037fbdaa9dfe843944e09a281170421ea67d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 21:08:08 +0000 Subject: [PATCH 81/81] Fix CI --- tests/kernels/test_flash_attn.py | 4 ++++ tests/models/test_fp8.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index da3c0333c51e1..89bdacc67fbc4 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -164,6 +164,10 @@ def test_varlen_with_paged_kv( head_size, dtype=dtype) value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index b7781c1cb36f5..664e951a89f2a 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -26,10 +26,10 @@ 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',