From fab469e3b06be517f879a5992741d96f2e1c2f91 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 14 May 2024 18:27:15 -0700 Subject: [PATCH] Revert "[Kernel] Use flash-attn for decoding (#3648)" This reverts commit 1356df53bd5d6877358aff3d2bbd95f28f8009a4. --- tests/kernels/test_flash_attn.py | 209 -------------------------- tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++--------- vllm/attention/selector.py | 14 -- vllm/worker/model_runner.py | 15 +- 6 files changed, 65 insertions(+), 313 deletions(-) delete mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py deleted file mode 100644 index 89bdacc67fbc4..0000000000000 --- a/tests/kernels/test_flash_attn.py +++ /dev/null @@ -1,209 +0,0 @@ -from typing import List, Optional, Tuple - -import pytest -import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache - -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] - - -def ref_paged_attn( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, - sliding_window: Optional[int] = None, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k).float() - empty_mask = torch.ones(query_len, kv_len) - mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() - if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() - mask |= sliding_window_mask - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1).to(v.dtype) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_flash_attn_with_paged_kv( - kv_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - ).squeeze(1) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - sliding_window: Optional[int], - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else - (-1, -1)) - scale = head_size**-0.5 - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - # Normalize the scale of the key and value caches to mitigate - # numerical instability. - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - sliding_window=sliding_window, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34e75..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - # "mosaicml/mpt-7b", # Broken + "mosaicml/mpt-7b", # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..e87a1783a83f1 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 11ecb2792ea9d..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,16 +1,20 @@ -"""Attention layer with FlashAttention.""" +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) - -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -34,9 +38,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -44,26 +47,19 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -109,14 +105,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - class FlashAttentionImpl(AttentionImpl): """ @@ -168,15 +156,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, @@ -187,20 +171,17 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with FlashAttention and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -208,20 +189,16 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - ) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -241,8 +218,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -263,32 +239,38 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], ) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,20 +93,6 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3f7e87c1de48c..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,27 +266,20 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums + prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) else: # The first prefill. - block_table = [] + prefix_block_tables.append([]) else: - block_table = [] + prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 - prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len)