From b3f42edaff84bb82bea32c2e72748310e8d63541 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 21 Oct 2024 18:38:00 +0000 Subject: [PATCH 01/32] Add flash attn kernel support for encoder-decoder models --- tests/kernels/test_encoder_decoder_attn.py | 4 +- tests/kernels/utils.py | 8 +- vllm/attention/backends/flash_attn.py | 96 +++++++++++++++++----- vllm/attention/layer.py | 1 + 4 files changed, 85 insertions(+), 24 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6b979d0558c46..1cf11bb7adca8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -21,7 +21,7 @@ from vllm.utils import is_hip # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] HEAD_SIZES = [64, 256] @@ -129,6 +129,7 @@ class that Attention will automatically select when it is constructed. scale = float(1.0 / (test_pt.head_size**0.5)) attn_backend = make_backend(test_pt.backend_name) + print('attn_backend ' + str(attn_backend)) attn = Attention( test_pt.num_heads, test_pt.head_size, @@ -774,6 +775,7 @@ def test_encoder_only( # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): + torch.set_default_dtype(torch.bfloat16) # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index a2d414f636e13..c906da1a8ccf5 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,7 +13,7 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, STR_FLASH_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some @@ -525,8 +525,12 @@ def make_backend(backend_name: str) -> AttentionBackend: if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() + elif backend_name == STR_FLASH_ATTN_VAL: + print('Hello') + from vllm.attention.backends.flash_attn import FlashAttentionBackend + return FlashAttentionBackend() + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d54dbdcb19495..e0d75375609cd 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -108,26 +108,12 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. 
E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -143,11 +129,63 @@ class FlashAttentionMetadata(AttentionMetadata): # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + use_cuda_graph: Optional[bool] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: @@ -179,6 +217,12 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables ) return self._cached_prefill_metadata @@ -210,6 +254,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables ) return self._cached_decode_metadata @@ -571,16 +621,20 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + output = torch.ops.vllm.unified_flash_attention( query, key, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b46f0721d0caf..a6b7052c68194 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,6 +78,7 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() + print('dtype ' + str(dtype)) attn_backend = get_attn_backend(head_size, sliding_window, dtype, kv_cache_dtype, block_size, is_attention_free, blocksparse_params From 551008e445d82eeaf8609093716173141b08729b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 22 Oct 2024 04:07:42 +0000 Subject: [PATCH 02/32] Run with flash attn --- tests/kernels/test_encoder_decoder_attn.py | 49 +++--- tests/kernels/utils.py | 20 ++- vllm/attention/backends/abstract.py | 2 + vllm/attention/backends/flash_attn.py | 170 +++++++++++++++++---- vllm/attention/backends/utils.py | 60 ++++++++ vllm/attention/backends/xformers.py | 5 + vllm/attention/layer.py | 1 + vllm/attention/selector.py | 11 +- 8 files changed, 263 insertions(+), 55 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 1cf11bb7adca8..32fb8a9011f81 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -11,6 +11,7 @@ import pytest import torch +from vllm.forward_context import get_forward_context, set_forward_context from tests.kernels.utils import * from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, @@ -620,14 +621,15 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test( @@ -661,12 +663,13 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -719,12 +722,13 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -787,6 +791,7 @@ def test_encoder_only( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init test_rsrcs = _make_test_resources(test_pt) + print('test_rsrcs.attn_backend ' + str(test_rsrcs.attn_backend)) # Construct encoder attention test params (only used # during prefill) @@ -803,6 +808,8 @@ def test_encoder_only( encoder_test_params=enc_test_params, cross_test_params=None, device=CUDA_DEVICE) + + print('prephase_attn_metadata ' + 
str(prephase_attn_metadata)) # PREFILL: encoder attention @@ -893,7 +900,7 @@ def test_e2e_enc_dec_attn( # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - + torch.set_default_dtype(torch.bfloat16) # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c906da1a8ccf5..4b84b08a2c51b 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -539,7 +539,7 @@ def _make_metadata_tensors( seq_lens: Optional[List[int]], context_lens: Optional[List[int]], encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] ) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], - torch.Tensor, Optional[int]]: + torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -569,9 +569,18 @@ def _make_metadata_tensors( max(encoder_seq_lens)) seq_start_loc = None + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device) + torch.cumsum( + encoder_seq_lens_tensor, dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + print('encoder_seq_start_loc ' + str(encoder_seq_start_loc)) return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -805,6 +814,7 @@ def make_test_metadata( * AttentionMetadata structure ''' + print('Here for metadata!!!') # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None @@ -850,6 +860,8 @@ def make_test_metadata( # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap + print('Here for metadata!!') + if is_prompt: # Prefill-phase scenario @@ -864,6 +876,7 @@ def make_test_metadata( _, _, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -885,6 +898,7 @@ def make_test_metadata( num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -909,6 +923,7 @@ def make_test_metadata( _, _, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -930,6 +945,7 @@ def make_test_metadata( num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96b..52e91ea751a7c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -16,9 +16,11 @@ class AttentionType(Enum): DECODER = auto() # Decoder attention between previous layer Q/K/V ENCODER = auto() # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V 
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e0d75375609cd..518d30af43d31 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -12,7 +12,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, - is_block_tables_empty) + is_block_tables_empty, get_seq_len_block_table_args) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -154,6 +154,12 @@ class FlashAttentionMetadata(AttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None @@ -194,32 +200,51 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + #assert self.seq_start_loc is not None + + #assert self.context_lens_tensor is not None + #assert self.block_tables is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - 
block_tables=self.block_tables[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables @@ -233,16 +258,25 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_decode_metadata is not None: return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=slot_mapping, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, @@ -252,11 +286,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, # Begin encoder & cross attn fields below... 
encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables @@ -507,6 +542,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) + print('build.seq_start_loc ' + str(seq_start_loc)) + return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -625,6 +662,8 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + print('Hello I am here!!') + if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " @@ -650,6 +689,7 @@ def forward( self.sliding_window, self.alibi_slopes, self.logits_soft_cap, + attn_type.value, ) return output @@ -672,23 +712,42 @@ def unified_flash_attention( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, + attn_type : int = 0, ) -> torch.Tensor: + print('attn_type ' + str(attn_type)) + current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata - num_tokens, hidden_size = query.shape + print('query.shape ' + str(query.shape)) + #num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) + hidden_size = num_heads * head_size + num_tokens = query.shape[0] key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: + print('Hell!!') key_cache = kv_cache[0] value_cache = kv_cache[1] + if attn_type == 1: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + print('Hello calling reshape_and_cache_flash') + # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. 
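For reference, a self-contained sketch of the integer encoding behind the new attn_type argument (the wrapper passes attn_type.value into the custom op, which then compares plain ints). This is purely illustrative and not part of the patch; it only shows what auto() numbering yields for the enum as declared in this series:

    from enum import Enum, auto

    class AttentionType(Enum):
        DECODER = auto()          # 1
        ENCODER = auto()          # 2
        ENCODER_ONLY = auto()     # 3
        ENCODER_DECODER = auto()  # 4

    # unified_flash_attention receives attn_type as an int and branches on
    # these values, e.g. to decide whether the cross-attention slot mapping
    # should be used for the KV-cache write.
    assert AttentionType.ENCODER_DECODER.value == 4
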
@@ -697,19 +756,53 @@ def unified_flash_attention( value, kv_cache[0], kv_cache[1], - attn_metadata.slot_mapping.flatten(), + updated_slot_mapping.flatten(), kv_cache_dtype, k_scale, v_scale, ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + print('Hell221!!') + #num_prefill_tokens = attn_metadata.num_prefill_tokens + #num_decode_tokens = attn_metadata.num_decode_tokens + print('attn_type ' + str(attn_type)) + print('AttentionType.ENCODER ' + str(AttentionType.ENCODER.value)) + if attn_type == AttentionType.ENCODER.value: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + print('Hello!!!') + elif attn_type == AttentionType.DECODER.value: + # Decoder self-attention supports chunked prefill. + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + else: # attn_type == AttentionType.ENCODER_DECODER + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + if attn_metadata.num_encoder_tokens is not None: + num_encoder_tokens = attn_metadata.num_encoder_tokens + else: + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + + print('Hell223!!') + print('key.shape[0] ' + str(key.shape[0])) assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + print('Hell224!!') assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - + print('Hell225!!') # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. @@ -722,6 +815,7 @@ def unified_flash_attention( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None + print('Hell222!!') if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -730,22 +824,40 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. 
+ print('prefill_meta ' + str(prefill_meta)) + _, seq_start_loc, max_prefill_seq_len, _ = get_seq_len_block_table_args( + prefill_meta, True, attn_type) + causal = True + if (attn_type == AttentionType.ENCODER.value or \ + attn_type == AttentionType.ENCODER_ONLY.value) : + causal = False + + print('seq_start_loc ' + str(seq_start_loc)) + print('max_prefill_seq_len ' + str(max_prefill_seq_len)) prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + #cu_seqlens_q=prefill_meta.seq_start_loc, + #cu_seqlens_k=prefill_meta.seq_start_loc, + #max_seqlen_q=prefill_meta.max_prefill_seq_len, + #max_seqlen_k=prefill_meta.max_prefill_seq_len, + cu_seqlens_q=seq_start_loc, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max_prefill_seq_len, + max_seqlen_k=max_prefill_seq_len, softmax_scale=softmax_scale, - causal=True, + #causal=True, + #causal=False, + causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ) else: # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Decoder only models currently support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) prefill_output = flash_attn_varlen_func( # noqa diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 358a223e7ed0e..46209d40bdb31 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -8,6 +8,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.attention.backends.abstract import AttentionType if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase @@ -432,3 +433,62 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, input_buffers["cross_block_tables"].copy_( attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER.value: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, attn_metadata.seq_start_loc, + max_seq_len, attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER.value: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER.value: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, None) + elif attn_type == AttentionType.ENCODER_ONLY.value: + assert is_prompt, "Should not have decode for encoder only model." + + # No block tables associated with encoder attention + return (attn_metadata.seq_lens_tensor, attn_metadata.seq_start_loc, + attn_metadata.max_prefill_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 25b86176f630e..25941347c9fe5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -135,6 +135,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ encoder_seq_start_loc: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a6b7052c68194..9ed2fb398f07c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -83,6 +83,7 @@ def __init__( kv_cache_dtype, block_size, is_attention_free, blocksparse_params is not None) + print('attn_backend ' + str(attn_backend)) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7edb7676ea2cd..499d8e554d80b 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -87,7 +87,7 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -@lru_cache(maxsize=None) +#@lru_cache(maxsize=None) def get_attn_backend( head_size: int, sliding_window: Optional[int], @@ -98,13 +98,14 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - + print('In get_attn_backend1') if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - + + print('In get_attn_backend2') backend = which_attn_to_use(head_size, sliding_window, dtype, kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: @@ -177,6 +178,7 @@ def which_attn_to_use( # ENVIRONMENT VARIABLE. backend_by_global_setting: Optional[_Backend] = ( get_global_forced_attn_backend()) + print('backend_by_global_setting ' + str(backend_by_global_setting)) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -296,6 +298,9 @@ def global_force_attn_backend_context_manager( # Globally force the new backend override global_force_attn_backend(attn_backend) + print('original value ' + str(original_value)) + print('new value ' + str(attn_backend)) + # Yield control back to the enclosed code block try: yield From 99cfaf1a7a6740cd4be86c0bbdac8019955e8fa2 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 23 Oct 2024 21:52:42 +0000 Subject: [PATCH 03/32] Flash Attn support --- tests/kernels/test_encoder_decoder_attn.py | 50 ++++++++------- tests/kernels/utils.py | 21 ++++-- vllm/attention/backends/flash_attn.py | 36 +++++++---- vllm/attention/backends/utils.py | 75 ++++++++++++++++++++++ 4 files changed, 142 insertions(+), 40 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 32fb8a9011f81..593d094c83e0d 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -22,7 +22,8 @@ from vllm.utils import is_hip # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] +#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] HEAD_SIZES = [64, 256] @@ -965,6 +966,9 @@ def test_e2e_enc_dec_attn( # - Is encoder attention result correct? 
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + print('First test succeeded') + print('prephase_attn_metadata ' + str(prephase_attn_metadata)) + # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( @@ -974,41 +978,43 @@ def test_e2e_enc_dec_attn( assert_actual_matches_ideal(prephase_dec_test_params, prephase_dec_pckd_act_out) + print('Second test succeeded') + # PREFILL: encoder/decoder cross-attention test - prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) + #prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + # test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, + # prephase_attn_metadata) # - Is prefill encoder/decoder cross-attention correct? - assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + #assert_actual_matches_ideal(prephase_cross_test_params, + # prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata - decphase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - False, - dec_qkv.q_seq_lens, - decoder_test_params=decphase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=decphase_cross_test_params, - device=CUDA_DEVICE) + #decphase_attn_metadata: AttentionMetadata = make_test_metadata( + # test_rsrcs.attn_backend, + # False, + # dec_qkv.q_seq_lens, + # decoder_test_params=decphase_dec_test_params, + # encoder_test_params=enc_test_params, + # cross_test_params=decphase_cross_test_params, + # device=CUDA_DEVICE) # DECODE: decoder self-attention test - decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + #decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + # test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) # - Is decode-phase decoder self-attention correct? - assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + #assert_actual_matches_ideal(decphase_dec_test_params, + # decphase_dec_pckd_act_out) # DECODE: encoder/decoder cross-attention test - decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + #decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + # test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) # - Is decode-phase encoder/decoder cross-attention correct? - assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + #assert_actual_matches_ideal(decphase_cross_test_params, + # decphase_cross_pckd_act_out) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 4b84b08a2c51b..ac2332dd20b52 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -538,7 +538,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( seq_lens: Optional[List[int]], context_lens: Optional[List[int]], encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] -) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. 
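As background for the hunks that build seq_start_loc and encoder_seq_start_loc via torch.cumsum: a minimal, self-contained sketch of the (batch_size + 1,) cumulative start-offset construction, assuming plain PyTorch. The helper name make_seq_start_loc is illustrative only and does not appear in the patch:

    from typing import List

    import torch

    def make_seq_start_loc(seq_lens: List[int],
                           device: str = "cpu") -> torch.Tensor:
        """Cumulative start offsets: seq_lens [4, 6] -> tensor([0, 4, 10])."""
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32,
                                       device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor, dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        return seq_start_loc

    assert make_seq_start_loc([4, 6]).tolist() == [0, 4, 10]
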
@@ -567,8 +567,19 @@ def _make_metadata_tensors( encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) max_encoder_seq_len = (None if encoder_seq_lens is None else max(encoder_seq_lens)) - + seq_start_loc = None + + if seq_lens_tensor is not None: + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=seq_lens_tensor.device) + torch.cumsum( + seq_lens_tensor, dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, @@ -874,7 +885,7 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, @@ -890,6 +901,7 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, @@ -921,7 +933,7 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, @@ -937,6 +949,7 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), context_lens_tensor=context_lens_tensor, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 518d30af43d31..7b88142db10b3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -12,7 +12,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, - is_block_tables_empty, get_seq_len_block_table_args) + is_block_tables_empty, get_query_key_seq_metadata) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -797,11 +797,11 @@ def unified_flash_attention( print('Hell223!!') print('key.shape[0] ' + str(key.shape[0])) - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa print('Hell224!!') - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + #assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + # f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa print('Hell225!!') # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] @@ -816,24 +816,29 @@ def unified_flash_attention( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None print('Hell222!!') - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. + print('Running in prefill!!!') if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. 
- print('prefill_meta ' + str(prefill_meta)) - _, seq_start_loc, max_prefill_seq_len, _ = get_seq_len_block_table_args( + #print('prefill_meta ' + str(prefill_meta)) + print('Hello123') + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = get_query_key_seq_metadata( prefill_meta, True, attn_type) + #print('seq_start_loc ' + str(seq_start_loc)) + #print('max_prefill_seq_len ' + str(max_prefill_seq_len)) causal = True if (attn_type == AttentionType.ENCODER.value or \ attn_type == AttentionType.ENCODER_ONLY.value) : causal = False + print('k_seq_len ' + str(k_seq_len)) + #print('q_seq_start_loc ' + str(q_seq_start_loc)) + #print('k_seq_start_loc ' + str(k_seq_start_loc)) + print('q_seq_len ' + str(q_seq_len)) - print('seq_start_loc ' + str(seq_start_loc)) - print('max_prefill_seq_len ' + str(max_prefill_seq_len)) prefill_output = flash_attn_varlen_func( q=query, k=key, @@ -842,10 +847,10 @@ def unified_flash_attention( #cu_seqlens_k=prefill_meta.seq_start_loc, #max_seqlen_q=prefill_meta.max_prefill_seq_len, #max_seqlen_k=prefill_meta.max_prefill_seq_len, - cu_seqlens_q=seq_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_prefill_seq_len, - max_seqlen_k=max_prefill_seq_len, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, #causal=True, #causal=False, @@ -855,6 +860,7 @@ def unified_flash_attention( softcap=logits_soft_cap, ) else: + print('Going into prefix') # prefix-enabled attention assert attn_type == AttentionType.DECODER, ( "Decoder only models currently support prefix caching") @@ -876,6 +882,7 @@ def unified_flash_attention( ) if decode_meta := attn_metadata.decode_metadata: + print('Running in decode!!!') # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. @@ -897,6 +904,7 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. + print('Running here!!!') decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 46209d40bdb31..34035054972a9 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -434,6 +434,81 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) +def get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + print('Hello456') + if attn_type == AttentionType.DECODER.value: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) + #print('attn_metadata.encoder_seq_start_loc ' + str(attn_metadata.encoder_seq_start_loc)) + print('attn_metadata.max_decode_seq_len ' + str(attn_metadata.max_decode_seq_len)) + print('attn_metadata.max_prefill_seq_len ' + str(attn_metadata.max_prefill_seq_len)) + #print('attn_metadata.max_encoder_seq_len ' + str(attn_metadata.max_encoder_seq_len)) + + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER.value: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) + print('attn_metadata.encoder_seq_start_loc ' + str(attn_metadata.encoder_seq_start_loc)) + print('attn_metadata.max_decode_seq_len ' + str(attn_metadata.max_decode_seq_len)) + print('attn_metadata.max_prefill_seq_len ' + str(attn_metadata.max_prefill_seq_len)) + print('attn_metadata.max_encoder_seq_len ' + str(attn_metadata.max_encoder_seq_len)) + return (attn_metadata.seq_start_loc, + max(attn_metadata.max_decode_seq_len, attn_metadata.max_prefill_seq_len), + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER.value: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY.value: + assert is_prompt, "Should not have decode for encoder only model." 
+ + # No block tables associated with encoder attention + return (attn_metadata.seq_start_loc, + attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, + attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( attn_metadata, is_prompt: bool, From 1e1fb570bd9addae6fdc4d7fb04e9d862f6ce214 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 25 Oct 2024 17:53:53 +0000 Subject: [PATCH 04/32] Some more fixes --- tests/kernels/test_encoder_decoder_attn.py | 74 ++++++++++++---------- tests/kernels/utils.py | 31 ++++++--- vllm/attention/backends/flash_attn.py | 11 +++- vllm/attention/backends/utils.py | 7 +- 4 files changed, 78 insertions(+), 45 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 593d094c83e0d..ca483f1cd9498 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -22,12 +22,13 @@ from vllm.utils import is_hip # List of support backends for encoder/decoder models -#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] - +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] +#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] +#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.FLASH_ATTN] HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] +#NUM_HEADS = [1] BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] @@ -131,7 +132,7 @@ class that Attention will automatically select when it is constructed. scale = float(1.0 / (test_pt.head_size**0.5)) attn_backend = make_backend(test_pt.backend_name) - print('attn_backend ' + str(attn_backend)) + #print('attn_backend ' + str(attn_backend)) attn = Attention( test_pt.num_heads, test_pt.head_size, @@ -148,7 +149,7 @@ class that Attention will automatically select when it is constructed. test_pt.num_heads, test_pt.head_size, test_pt.block_size, - device=CUDA_DEVICE) + device=CUDA_DEVICE, backend=test_pt.backend_name) return TestResources(scale, attn_backend, attn, kv_cache) @@ -810,7 +811,7 @@ def test_encoder_only( cross_test_params=None, device=CUDA_DEVICE) - print('prephase_attn_metadata ' + str(prephase_attn_metadata)) + #print('prephase_attn_metadata ' + str(prephase_attn_metadata)) # PREFILL: encoder attention @@ -820,6 +821,9 @@ def test_encoder_only( # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) +@pytest.fixture(autouse=True) +def print_test_name(request): + print(f"Running test: {request.node.name}") @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -959,15 +963,15 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder attention - enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) + #enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + # enc_test_params, + # prephase_attn_metadata) # - Is encoder attention result correct? 
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + #assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - print('First test succeeded') - print('prephase_attn_metadata ' + str(prephase_attn_metadata)) + #print('First test succeeded') + #print('prephase_attn_metadata ' + str(prephase_attn_metadata)) # PREFILL: decoder self-attention test @@ -982,39 +986,43 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder/decoder cross-attention test - #prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - # test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - # prephase_attn_metadata) + prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, + prephase_attn_metadata) # - Is prefill encoder/decoder cross-attention correct? - #assert_actual_matches_ideal(prephase_cross_test_params, - # prephase_cross_pckd_act_out) + assert_actual_matches_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) + + print('Third test succeeded') # DECODE: build decode-phase attention metadata - #decphase_attn_metadata: AttentionMetadata = make_test_metadata( - # test_rsrcs.attn_backend, - # False, - # dec_qkv.q_seq_lens, - # decoder_test_params=decphase_dec_test_params, - # encoder_test_params=enc_test_params, - # cross_test_params=decphase_cross_test_params, - # device=CUDA_DEVICE) + decphase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + False, + dec_qkv.q_seq_lens, + decoder_test_params=decphase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=decphase_cross_test_params, + device=CUDA_DEVICE) # DECODE: decoder self-attention test - #decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - # test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) # - Is decode-phase decoder self-attention correct? - #assert_actual_matches_ideal(decphase_dec_test_params, - # decphase_dec_pckd_act_out) + assert_actual_matches_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) + print('Fourth test succeeded') # DECODE: encoder/decoder cross-attention test - #decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - # test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) # - Is decode-phase encoder/decoder cross-attention correct? 
- #assert_actual_matches_ideal(decphase_cross_test_params, - # decphase_cross_pckd_act_out) + assert_actual_matches_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) + print('Fifth test succeeded') diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ac2332dd20b52..c8f2d6f1359cf 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -527,7 +527,7 @@ def make_backend(backend_name: str) -> AttentionBackend: from vllm.attention.backends.xformers import XFormersBackend return XFormersBackend() elif backend_name == STR_FLASH_ATTN_VAL: - print('Hello') + #print('Hello') from vllm.attention.backends.flash_attn import FlashAttentionBackend return FlashAttentionBackend() @@ -537,7 +537,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( seq_lens: Optional[List[int]], context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] + encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str], ) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[int]]: ''' @@ -580,6 +580,10 @@ def _make_metadata_tensors( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) + print('seq_start_loc ' + str(seq_start_loc)) + print('seq_lens_tensor ' + str(seq_lens_tensor)) + print('max_seq_len ' + str(max_seq_len)) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, @@ -588,7 +592,9 @@ def _make_metadata_tensors( encoder_seq_lens_tensor, dim=0, dtype=encoder_seq_start_loc.dtype, out=encoder_seq_start_loc[1:]) - print('encoder_seq_start_loc ' + str(encoder_seq_start_loc)) + + + #print('encoder_seq_start_loc ' + str(encoder_seq_start_loc)) return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len) @@ -599,6 +605,7 @@ def make_kv_cache(num_blocks: int, head_size: int, block_size: int, device: Union[torch.device, str], + backend: str, default_val: float = 0.0) -> torch.Tensor: ''' Create a fake KV cache. @@ -616,9 +623,14 @@ def make_kv_cache(num_blocks: int, * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) + if backend == 'XFORMERS': + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + elif backend == 'FLASH_ATTN': + kv_cache = torch.rand( + (2, num_blocks, block_size, num_heads, head_size)).to(device) + else: + raise ValueError(f"Unknown backend value: '{backend}'. 
Expected 'XFORMERS' or 'FLASH_ATTN'.") if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -825,7 +837,7 @@ def make_test_metadata( * AttentionMetadata structure ''' - print('Here for metadata!!!') + #print('Here for metadata!!!') # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None @@ -871,7 +883,7 @@ def make_test_metadata( # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap - print('Here for metadata!!') + #print('Here for metadata!!') if is_prompt: # Prefill-phase scenario @@ -952,6 +964,7 @@ def make_test_metadata( seq_start_loc=seq_start_loc, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), + max_decode_query_len=1, context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, @@ -979,7 +992,7 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, ''' ideal_output = test_params.packed_qkvo.ideal_output torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016) # Copied/modified from torch._refs.__init__.py diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 7b88142db10b3..75ad366f56b0e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -723,6 +723,8 @@ def unified_flash_attention( attn_metadata: FlashAttentionMetadata = current_metadata print('query.shape ' + str(query.shape)) + print('num_heads ' + str(num_heads)) + print('num_kv_heads ' + str(num_kv_heads)) #num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. @@ -737,16 +739,18 @@ def unified_flash_attention( key_cache = kv_cache[0] value_cache = kv_cache[1] - if attn_type == 1: + if attn_type == 4: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running updated_slot_mapping = attn_metadata.cross_slot_mapping else: # Update self-attention KV cache (prefill/decode) + print('I am here!!!') updated_slot_mapping = attn_metadata.slot_mapping print('Hello calling reshape_and_cache_flash') + print('updated_slot_mapping ' + str(updated_slot_mapping)) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are @@ -832,7 +836,8 @@ def unified_flash_attention( #print('max_prefill_seq_len ' + str(max_prefill_seq_len)) causal = True if (attn_type == AttentionType.ENCODER.value or \ - attn_type == AttentionType.ENCODER_ONLY.value) : + attn_type == AttentionType.ENCODER_ONLY.value or \ + attn_type == AttentionType.ENCODER_DECODER.value) : causal = False print('k_seq_len ' + str(k_seq_len)) #print('q_seq_start_loc ' + str(q_seq_start_loc)) @@ -905,6 +910,8 @@ def unified_flash_attention( else: # Use flash_attn_with_kvcache for normal decoding. 
print('Running here!!!') + print('decode_meta.block_tables.shape ' + str(decode_meta.block_tables.shape)) + print('key_cache.shape ' + str(key_cache.shape)) decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 34035054972a9..ce10e195214a7 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -466,11 +466,16 @@ def get_query_key_seq_metadata( if attn_type == AttentionType.DECODER.value: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run + print('is_prompt ' + str(is_prompt)) + #print('attn_metadata ' + str(attn_metadata)) if is_prompt: max_seq_len = attn_metadata.max_prefill_seq_len else: max_seq_len = attn_metadata.max_decode_seq_len - print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) + #if attn_metadata.seq_start_loc is None: + # print('attn_metadata.seq_start_loc is None') + #else: + # print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) #print('attn_metadata.encoder_seq_start_loc ' + str(attn_metadata.encoder_seq_start_loc)) print('attn_metadata.max_decode_seq_len ' + str(attn_metadata.max_decode_seq_len)) print('attn_metadata.max_prefill_seq_len ' + str(attn_metadata.max_prefill_seq_len)) From e16cbcbbc8f26f0a7282be23e6280debd966957c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 25 Oct 2024 18:44:42 +0000 Subject: [PATCH 05/32] More fixes --- vllm/attention/backends/flash_attn.py | 139 +++++++++++++++++++------- 1 file changed, 105 insertions(+), 34 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 75ad366f56b0e..a9fac1713d8c8 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -695,6 +695,65 @@ def forward( return output +def _get_seq_len_block_table_args( + attn_metadata: FlashAttentionMetadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER.value: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER.value: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER.value: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + elif attn_type == AttentionType.ENCODER_ONLY.value: + assert is_prompt, "Should not have decode for encoder only model." + + # No block tables associated with encoder attention + return (attn_metadata.seq_lens_tensor, + attn_metadata.max_prefill_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + + @torch.library.custom_op("vllm::unified_flash_attention", mutates_args=["kv_cache"]) def unified_flash_attention( @@ -731,40 +790,42 @@ def unified_flash_attention( query = query.view(-1, num_heads, head_size) hidden_size = num_heads * head_size num_tokens = query.shape[0] - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) + if (key is not None) and (value is not None): + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) - if kv_cache.numel() > 0: + if kv_cache.numel() > 0 : print('Hell!!') key_cache = kv_cache[0] value_cache = kv_cache[1] - if attn_type == 4: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - print('I am here!!!') - updated_slot_mapping = attn_metadata.slot_mapping - - print('Hello calling reshape_and_cache_flash') - print('updated_slot_mapping ' + str(updated_slot_mapping)) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) + if (key is not None) and (value is not None): + if attn_type == 4: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + print('I am here!!!') + updated_slot_mapping = attn_metadata.slot_mapping + + print('Hello calling reshape_and_cache_flash') + print('updated_slot_mapping ' + str(updated_slot_mapping)) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) print('Hell221!!') #num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -800,7 +861,7 @@ def unified_flash_attention( num_decode_tokens = attn_metadata.num_decode_tokens print('Hell223!!') - print('key.shape[0] ' + str(key.shape[0])) + #print('key.shape[0] ' + str(key.shape[0])) #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa print('Hell224!!') @@ -811,8 +872,9 @@ def unified_flash_attention( decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + if (key is not None) and (value is not None): + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -912,12 +974,21 @@ def unified_flash_attention( print('Running here!!!') print('decode_meta.block_tables.shape ' + str(decode_meta.block_tables.shape)) print('key_cache.shape ' + str(key_cache.shape)) + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + #block_table=decode_meta.block_tables, + block_table=block_tables_arg, + #cache_seqlens=decode_meta.seq_lens_tensor, + cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, causal=True, alibi_slopes=alibi_slopes, From 1b2c06045c31c18565f0fcaead00e94fe282171a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 25 Oct 2024 21:28:44 +0000 Subject: [PATCH 06/32] More fixes to model_runner --- vllm/attention/backends/flash_attn.py | 51 +++++++++++---------------- vllm/attention/backends/utils.py | 39 ++++++++++---------- vllm/worker/enc_dec_model_runner.py | 5 ++- 3 files changed, 46 insertions(+), 49 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index a9fac1713d8c8..4127806d81527 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -154,16 +154,12 @@ class FlashAttentionMetadata(AttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None 
encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. encoder_seq_start_loc: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder num_encoder_tokens: Optional[int] = None @@ -204,10 +200,6 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: or (self.encoder_seq_lens is not None)) assert ((self.seq_lens_tensor is not None) or (self.encoder_seq_lens_tensor is not None)) - #assert self.seq_start_loc is not None - - #assert self.context_lens_tensor is not None - #assert self.block_tables is not None # Compute some attn_metadata fields which default to None query_start_loc = (None if self.query_start_loc is None else @@ -795,11 +787,10 @@ def unified_flash_attention( value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0 : - print('Hell!!') key_cache = kv_cache[0] value_cache = kv_cache[1] - if (key is not None) and (value is not None): + if (attn_type != AttentionType.ENCODER.value) and (key is not None) and (value is not None): if attn_type == 4: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, @@ -807,11 +798,11 @@ def unified_flash_attention( updated_slot_mapping = attn_metadata.cross_slot_mapping else: # Update self-attention KV cache (prefill/decode) - print('I am here!!!') + #print('I am here!!!') updated_slot_mapping = attn_metadata.slot_mapping - print('Hello calling reshape_and_cache_flash') - print('updated_slot_mapping ' + str(updated_slot_mapping)) + #print('Hello calling reshape_and_cache_flash') + #print('updated_slot_mapping ' + str(updated_slot_mapping)) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are @@ -827,11 +818,11 @@ def unified_flash_attention( v_scale, ) - print('Hell221!!') + #print('Hell221!!') #num_prefill_tokens = attn_metadata.num_prefill_tokens #num_decode_tokens = attn_metadata.num_decode_tokens - print('attn_type ' + str(attn_type)) - print('AttentionType.ENCODER ' + str(AttentionType.ENCODER.value)) + #print('attn_type ' + str(attn_type)) + #print('AttentionType.ENCODER ' + str(AttentionType.ENCODER.value)) if attn_type == AttentionType.ENCODER.value: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them @@ -840,7 +831,7 @@ def unified_flash_attention( num_prefill_tokens = attn_metadata.num_encoder_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 - print('Hello!!!') + #print('Hello!!!') elif attn_type == AttentionType.DECODER.value: # Decoder self-attention supports chunked prefill. 
num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -860,14 +851,14 @@ def unified_flash_attention( num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - print('Hell223!!') + #print('Hell223!!') #print('key.shape[0] ' + str(key.shape[0])) #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - print('Hell224!!') + #print('Hell224!!') #assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ # f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - print('Hell225!!') + #print('Hell225!!') # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. @@ -881,17 +872,17 @@ def unified_flash_attention( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None - print('Hell222!!') + #print('Hell222!!') if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - print('Running in prefill!!!') + #print('Running in prefill!!!') if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. #print('prefill_meta ' + str(prefill_meta)) - print('Hello123') + #print('Hello123') q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = get_query_key_seq_metadata( prefill_meta, True, attn_type) #print('seq_start_loc ' + str(seq_start_loc)) @@ -901,10 +892,10 @@ def unified_flash_attention( attn_type == AttentionType.ENCODER_ONLY.value or \ attn_type == AttentionType.ENCODER_DECODER.value) : causal = False - print('k_seq_len ' + str(k_seq_len)) + #print('k_seq_len ' + str(k_seq_len)) #print('q_seq_start_loc ' + str(q_seq_start_loc)) #print('k_seq_start_loc ' + str(k_seq_start_loc)) - print('q_seq_len ' + str(q_seq_len)) + #print('q_seq_len ' + str(q_seq_len)) prefill_output = flash_attn_varlen_func( q=query, @@ -927,7 +918,7 @@ def unified_flash_attention( softcap=logits_soft_cap, ) else: - print('Going into prefix') + #print('Going into prefix') # prefix-enabled attention assert attn_type == AttentionType.DECODER, ( "Decoder only models currently support prefix caching") @@ -949,7 +940,7 @@ def unified_flash_attention( ) if decode_meta := attn_metadata.decode_metadata: - print('Running in decode!!!') + #print('Running in decode!!!') # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. @@ -971,9 +962,9 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. 
- print('Running here!!!') - print('decode_meta.block_tables.shape ' + str(decode_meta.block_tables.shape)) - print('key_cache.shape ' + str(key_cache.shape)) + #print('Running here!!!') + #print('decode_meta.block_tables.shape ' + str(decode_meta.block_tables.shape)) + #print('key_cache.shape ' + str(key_cache.shape)) ( seq_lens_arg, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index ce10e195214a7..4aa1cda169e84 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -315,14 +315,17 @@ def graph_capture_get_metadata_for_batch( block_tables=self._graph_block_tables[:batch_size], use_cuda_graph=True, ) - if is_encoder_decoder_model: + #if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) + # assert self.runner.attn_backend.get_name() == "xformers", \ + # f"Expected attn_backend name to be 'xformers', but "\ + # f" got '{self.runner.attn_backend.get_name()}'" + # self._update_captured_metadata_for_enc_dec_model( + # batch_size=batch_size, attn_metadata=attn_metadata) + + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) return attn_metadata @@ -335,14 +338,14 @@ def get_graph_input_buffers( "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - if is_encoder_decoder_model: + #if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) + #assert self.runner.attn_backend.get_name() == "xformers", \ + #f"Expected attn_backend name to be 'xformers', but "\ + #f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers def prepare_graph_input_buffers( @@ -354,14 +357,14 @@ def prepare_graph_input_buffers( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: + #if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. 
- assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) + #assert self.runner.attn_backend.get_name() == "xformers", \ + #f"Expected attn_backend name to be 'xformers', but "\ + #f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6a00444f5098b..a3325a7560b08 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -102,7 +102,7 @@ def __init__( the base-class constructor. ''' - self._maybe_force_supported_attention_backend() + #self._maybe_force_supported_attention_backend() super().__init__( model_config, @@ -514,6 +514,7 @@ def _prepare_encoder_model_input_tensors( encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, + dtype=torch.int32, device=self.device) torch.cumsum(encoder_seq_lens_tensor, dim=0, @@ -528,6 +529,7 @@ def _prepare_encoder_model_input_tensors( attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_lens_tensor, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, attn_metadata.cross_slot_mapping, attn_metadata.cross_block_tables, ) = ( @@ -535,6 +537,7 @@ def _prepare_encoder_model_input_tensors( encoder_seq_lens, encoder_seq_lens_tensor, max_encoder_seq_len, + encoder_seq_start_loc, cross_slot_mapping_tensor, cross_block_tables, ) From b619e32a5fa6ff5f0c7dc635c423fb94c7ea9bd2 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sat, 26 Oct 2024 05:52:21 +0000 Subject: [PATCH 07/32] Fixes --- tests/kernels/test_encoder_decoder_attn.py | 18 ++- tests/kernels/utils.py | 48 +++--- vllm/attention/backends/abstract.py | 1 - vllm/attention/backends/flash_attn.py | 161 +++++++-------------- vllm/attention/backends/utils.py | 133 +++++------------ vllm/attention/backends/xformers.py | 36 +---- vllm/attention/selector.py | 2 +- vllm/model_executor/models/bart.py | 9 +- vllm/worker/enc_dec_model_runner.py | 1 - 9 files changed, 137 insertions(+), 272 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ca483f1cd9498..733006056da46 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -149,7 +149,8 @@ class that Attention will automatically select when it is constructed. test_pt.num_heads, test_pt.head_size, test_pt.block_size, - device=CUDA_DEVICE, backend=test_pt.backend_name) + device=CUDA_DEVICE, + backend=test_pt.backend_name) return TestResources(scale, attn_backend, attn, kv_cache) @@ -596,6 +597,7 @@ def _run_encoder_attention_test( attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder attention. 
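For reference (this aside is not part of the patch): the hunk below passes the packed query to attn.forward() flattened to [num_tokens, hidden_size], where hidden_size == num_heads * head_size, rather than the per-head [num_tokens, num_heads, head_size] layout the test QKV is built in. A minimal standalone sketch of that reshape, with arbitrary example sizes:

import torch

num_tokens, num_heads, head_size = 8, 16, 64        # arbitrary example sizes
q = torch.randn(num_tokens, num_heads, head_size)   # per-head layout built by the test helpers
q_flat = q.view(-1, num_heads * head_size)          # (8, 1024): the layout handed to attn.forward()
assert q_flat.view(-1, num_heads, head_size).shape == q.shape   # lossless round trip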
@@ -624,12 +626,12 @@ def _run_encoder_attention_test( packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward(packed_qkv.query, + return attn.forward(packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), packed_qkv.key, packed_qkv.value, torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), + dtype=torch.float32, + device=packed_qkv.query.device), attn_metadata, attn_type=attn_type) @@ -810,21 +812,23 @@ def test_encoder_only( encoder_test_params=enc_test_params, cross_test_params=None, device=CUDA_DEVICE) - + #print('prephase_attn_metadata ' + str(prephase_attn_metadata)) # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + test_rsrcs.attn, enc_test_params, prephase_attn_metadata, test_pt=test_pt)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + @pytest.fixture(autouse=True) def print_test_name(request): print(f"Running test: {request.node.name}") + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -993,7 +997,7 @@ def test_e2e_enc_dec_attn( # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, prephase_cross_pckd_act_out) - + print('Third test succeeded') # DECODE: build decode-phase attention metadata diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c8f2d6f1359cf..127b980480fbb 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,8 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, STR_FLASH_ATTN_VAL, - make_tensor_with_pad) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, + STR_FLASH_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. 
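For reference (not part of the patch): the _make_metadata_tensors hunks below build cumulative start-location tensors (seq_start_loc, encoder_seq_start_loc) from per-sequence lengths with torch.cumsum; these are the cu_seqlens_q / cu_seqlens_k style offsets that flash_attn_varlen_func consumes elsewhere in this series. A minimal sketch of the same pattern:

import torch

seq_lens = torch.tensor([4, 6], dtype=torch.int32)       # two sequences in the batch
seq_start_loc = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:])
# seq_start_loc is now tensor([0, 4, 10]): the token offset at which each sequence starts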
@@ -536,8 +536,10 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( - seq_lens: Optional[List[int]], context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str], + seq_lens: Optional[List[int]], + context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], + device: Union[torch.device, str], ) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[int]]: ''' @@ -567,37 +569,35 @@ def _make_metadata_tensors( encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) max_encoder_seq_len = (None if encoder_seq_lens is None else max(encoder_seq_lens)) - + seq_start_loc = None if seq_lens_tensor is not None: - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + - 1, + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=seq_lens_tensor.device) - torch.cumsum( - seq_lens_tensor, dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) print('seq_start_loc ' + str(seq_start_loc)) print('seq_lens_tensor ' + str(seq_lens_tensor)) print('max_seq_len ' + str(max_seq_len)) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + - 1, + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=encoder_seq_lens_tensor.device) - torch.cumsum( - encoder_seq_lens_tensor, dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + #print('encoder_seq_start_loc ' + str(encoder_seq_start_loc)) return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len) + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, + max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -630,7 +630,9 @@ def make_kv_cache(num_blocks: int, kv_cache = torch.rand( (2, num_blocks, block_size, num_heads, head_size)).to(device) else: - raise ValueError(f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'.") + raise ValueError( + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." + ) if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -992,7 +994,9 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, ''' ideal_output = test_params.packed_qkvo.ideal_output torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016) + output_under_test.view_as(ideal_output), + atol=0.01, + rtol=0.016) # Copied/modified from torch._refs.__init__.py diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 52e91ea751a7c..cb6d267092782 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -20,7 +20,6 @@ class AttentionType(Enum): ENCODER_DECODER = auto() # Attention between dec. Q and enc. 
K/V - class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4127806d81527..f5ca33696bf39 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -12,7 +12,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, - is_block_tables_empty, get_query_key_seq_metadata) + is_block_tables_empty, + get_seq_len_block_table_args) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -201,7 +202,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: assert ((self.seq_lens_tensor is not None) or (self.encoder_seq_lens_tensor is not None)) - # Compute some attn_metadata fields which default to None + # Compute some attn_metadata fields which default to None query_start_loc = (None if self.query_start_loc is None else self.query_start_loc[:self.num_prefills + 1]) slot_mapping = (None if self.slot_mapping is None else @@ -211,7 +212,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_lens_tensor = (None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills]) seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) + self.seq_start_loc[:self.num_prefills + 1]) context_lens_tensor = (None if self.context_lens_tensor is None else self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else @@ -239,8 +240,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: encoder_seq_start_loc=self.encoder_seq_start_loc, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables - ) + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -261,7 +261,6 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, @@ -286,8 +285,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: encoder_seq_start_loc=self.encoder_seq_start_loc, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables - ) + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, @@ -654,8 +652,6 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - print('Hello I am here!!') - if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " @@ -687,65 +683,44 @@ def forward( return output -def _get_seq_len_block_table_args( - attn_metadata: FlashAttentionMetadata, +def _get_query_key_seq_metadata( + attn_metadata, is_prompt: bool, attn_type: AttentionType, ) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER.value: + if attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run if is_prompt: max_seq_len = attn_metadata.max_prefill_seq_len else: max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER.value: + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + return (attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER.value: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - elif attn_type == AttentionType.ENCODER_ONLY.value: + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: assert is_prompt, "Should not have decode for encoder only model." 
- - # No block tables associated with encoder attention - return (attn_metadata.seq_lens_tensor, - attn_metadata.max_prefill_seq_len, None) + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") - @torch.library.custom_op("vllm::unified_flash_attention", mutates_args=["kv_cache"]) def unified_flash_attention( @@ -763,10 +738,15 @@ def unified_flash_attention( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, - attn_type : int = 0, + attn_type_int_val: int = 0, ) -> torch.Tensor: - print('attn_type ' + str(attn_type)) + # Convert integer attn_type to enum + try: + attn_type = AttentionType(attn_type_int_val) + except ValueError: + raise AttributeError( + f"Invalid attention type {str(attn_type_int_val)}") current_metadata = get_forward_context() assert current_metadata is not None @@ -774,6 +754,10 @@ def unified_flash_attention( attn_metadata: FlashAttentionMetadata = current_metadata print('query.shape ' + str(query.shape)) + if key is not None: + print('key.shape ' + str(key.shape)) + print('value.shape ' + str(key.shape)) + print('attn_type ' + str(attn_type)) print('num_heads ' + str(num_heads)) print('num_kv_heads ' + str(num_kv_heads)) #num_tokens, hidden_size = query.shape @@ -786,24 +770,21 @@ def unified_flash_attention( key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) - if kv_cache.numel() > 0 : + if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] - if (attn_type != AttentionType.ENCODER.value) and (key is not None) and (value is not None): - if attn_type == 4: + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running updated_slot_mapping = attn_metadata.cross_slot_mapping else: # Update self-attention KV cache (prefill/decode) - #print('I am here!!!') updated_slot_mapping = attn_metadata.slot_mapping - #print('Hello calling reshape_and_cache_flash') - #print('updated_slot_mapping ' + str(updated_slot_mapping)) - # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. @@ -818,24 +799,16 @@ def unified_flash_attention( v_scale, ) - #print('Hell221!!') - #num_prefill_tokens = attn_metadata.num_prefill_tokens - #num_decode_tokens = attn_metadata.num_decode_tokens - #print('attn_type ' + str(attn_type)) - #print('AttentionType.ENCODER ' + str(AttentionType.ENCODER.value)) - if attn_type == AttentionType.ENCODER.value: + if attn_type == AttentionType.ENCODER: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 - #print('Hello!!!') - elif attn_type == AttentionType.DECODER.value: + elif attn_type == AttentionType.DECODER: # Decoder self-attention supports chunked prefill. 
num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens # Only enforce this shape-constraint for decoder # self-attention @@ -845,13 +818,8 @@ def unified_flash_attention( # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - - #print('Hell223!!') + #print('key.shape[0] ' + str(key.shape[0])) #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa @@ -872,53 +840,35 @@ def unified_flash_attention( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None - #print('Hell222!!') if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - #print('Running in prefill!!!') if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - #print('prefill_meta ' + str(prefill_meta)) - #print('Hello123') - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = get_query_key_seq_metadata( + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata( prefill_meta, True, attn_type) - #print('seq_start_loc ' + str(seq_start_loc)) - #print('max_prefill_seq_len ' + str(max_prefill_seq_len)) causal = True - if (attn_type == AttentionType.ENCODER.value or \ - attn_type == AttentionType.ENCODER_ONLY.value or \ - attn_type == AttentionType.ENCODER_DECODER.value) : + if (attn_type == AttentionType.ENCODER or \ + attn_type == AttentionType.ENCODER_ONLY or \ + attn_type == AttentionType.ENCODER_DECODER) : causal = False - #print('k_seq_len ' + str(k_seq_len)) - #print('q_seq_start_loc ' + str(q_seq_start_loc)) - #print('k_seq_start_loc ' + str(k_seq_start_loc)) - #print('q_seq_len ' + str(q_seq_len)) - prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - #cu_seqlens_q=prefill_meta.seq_start_loc, - #cu_seqlens_k=prefill_meta.seq_start_loc, - #max_seqlen_q=prefill_meta.max_prefill_seq_len, - #max_seqlen_k=prefill_meta.max_prefill_seq_len, cu_seqlens_q=q_seq_start_loc, cu_seqlens_k=k_seq_start_loc, max_seqlen_q=q_seq_len, max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, - #causal=True, - #causal=False, causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ) else: - #print('Going into prefix') # prefix-enabled attention assert attn_type == AttentionType.DECODER, ( "Decoder only models currently support prefix caching") @@ -940,7 +890,6 @@ def unified_flash_attention( ) if decode_meta := attn_metadata.decode_metadata: - #print('Running in decode!!!') # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. @@ -962,23 +911,17 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. 
- #print('Running here!!!') - #print('decode_meta.block_tables.shape ' + str(decode_meta.block_tables.shape)) - #print('key_cache.shape ' + str(key_cache.shape)) - ( seq_lens_arg, - max_seq_len_arg, + _, block_tables_arg, - ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) - + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, - #block_table=decode_meta.block_tables, block_table=block_tables_arg, - #cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, causal=True, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 4aa1cda169e84..1f62cdcf6dd4e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -316,8 +316,8 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, ) #if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. + # The encoder decoder model works only with XFormers backend. + # Assert the same. # assert self.runner.attn_backend.get_name() == "xformers", \ # f"Expected attn_backend name to be 'xformers', but "\ # f" got '{self.runner.attn_backend.get_name()}'" @@ -339,11 +339,11 @@ def get_graph_input_buffers( "block_tables": attn_metadata.decode_metadata.block_tables, } #if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - #assert self.runner.attn_backend.get_name() == "xformers", \ - #f"Expected attn_backend name to be 'xformers', but "\ - #f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers backend. + # Assert the same. + #assert self.runner.attn_backend.get_name() == "xformers", \ + #f"Expected attn_backend name to be 'xformers', but "\ + #f" got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers @@ -358,13 +358,13 @@ def prepare_graph_input_buffers( input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) #if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - #assert self.runner.attn_backend.get_name() == "xformers", \ - #f"Expected attn_backend name to be 'xformers', but "\ - #f" got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) + # The encoder decoder model works only with XFormers backend. + # Assert the same. + #assert self.runner.attn_backend.get_name() == "xformers", \ + #f"Expected attn_backend name to be 'xformers', but "\ + #f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model(attn_metadata, + input_buffers) def begin_forward(self, model_input) -> None: return @@ -437,84 +437,25 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) -def get_query_key_seq_metadata( - attn_metadata, - is_prompt: bool, - attn_type: AttentionType, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) +def is_all_cross_attn_metadata_set(attn_metadata): ''' - print('Hello456') - if attn_type == AttentionType.DECODER.value: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - print('is_prompt ' + str(is_prompt)) - #print('attn_metadata ' + str(attn_metadata)) - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - #if attn_metadata.seq_start_loc is None: - # print('attn_metadata.seq_start_loc is None') - #else: - # print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) - #print('attn_metadata.encoder_seq_start_loc ' + str(attn_metadata.encoder_seq_start_loc)) - print('attn_metadata.max_decode_seq_len ' + str(attn_metadata.max_decode_seq_len)) - print('attn_metadata.max_prefill_seq_len ' + str(attn_metadata.max_prefill_seq_len)) - #print('attn_metadata.max_encoder_seq_len ' + str(attn_metadata.max_encoder_seq_len)) - - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.seq_start_loc, max_seq_len) - - elif attn_type == AttentionType.ENCODER_DECODER.value: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - print('attn_metadata.seq_start_loc ' + str(attn_metadata.seq_start_loc)) - print('attn_metadata.encoder_seq_start_loc ' + str(attn_metadata.encoder_seq_start_loc)) - print('attn_metadata.max_decode_seq_len ' + str(attn_metadata.max_decode_seq_len)) - print('attn_metadata.max_prefill_seq_len ' + str(attn_metadata.max_prefill_seq_len)) - print('attn_metadata.max_encoder_seq_len ' + str(attn_metadata.max_encoder_seq_len)) - return (attn_metadata.seq_start_loc, - max(attn_metadata.max_decode_seq_len, attn_metadata.max_prefill_seq_len), - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER.value: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER_ONLY.value: - assert is_prompt, "Should not have decode for encoder only model." + All attention metadata required for enc/dec cross-attention is set. - # No block tables associated with encoder attention - return (attn_metadata.seq_start_loc, - attn_metadata.max_prefill_seq_len, - attn_metadata.seq_start_loc, - attn_metadata.max_prefill_seq_len) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") + Superset of encoder attention required metadata. 
+ ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) def get_seq_len_block_table_args( @@ -546,32 +487,24 @@ def get_seq_len_block_table_args( * Appropriate block tables (or None) ''' - if attn_type == AttentionType.DECODER.value: + if attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run if is_prompt: max_seq_len = attn_metadata.max_prefill_seq_len else: max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, attn_metadata.seq_start_loc, - max_seq_len, attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER.value: + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER.value: + elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, None) - elif attn_type == AttentionType.ENCODER_ONLY.value: - assert is_prompt, "Should not have decode for encoder only model." - - # No block tables associated with encoder attention - return (attn_metadata.seq_lens_tensor, attn_metadata.seq_start_loc, - attn_metadata.max_prefill_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 25941347c9fe5..a40a16ecc4c69 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -12,7 +12,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) + CommonMetadataBuilder, + is_all_encoder_attn_metadata_set, + is_all_cross_attn_metadata_set, + get_seq_len_block_table_args) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -167,9 +170,7 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return is_all_encoder_attn_metadata_set(self) @property def is_all_cross_attn_metadata_set(self): @@ -178,9 +179,7 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. 
''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return is_all_cross_attn_metadata_set(self) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -356,28 +355,7 @@ def _get_seq_len_block_table_args( * Appropriate max sequence-length scalar * Appropriate block tables (or None) ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") + return get_seq_len_block_table_args(attn_metadata, is_prompt, attn_type) class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 499d8e554d80b..6b509222a8dc0 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -104,7 +104,7 @@ def get_attn_backend( from vllm.attention.backends.blocksparse_attn import ( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - + print('In get_attn_backend2') backend = which_attn_to_use(head_size, sliding_window, dtype, kv_cache_dtype, block_size, is_attention_free) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index cbdacf779b089..d115ebb1c9249 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -184,6 +184,9 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + print('Enc qkv_shape ' + str(qkv.shape)) + print('Enc q_shape ' + str(q.shape)) + print('Enc hidden_states_shape ' + str(hidden_states.shape)) attn_output = self.attn(q, k, @@ -266,6 +269,10 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + print('Dec qkv_shape ' + str(qkv.shape)) + print('Dec q_shape ' + str(q.shape)) + print('Dec hidden_states_shape ' + str(hidden_states.shape)) + attn_output = self.attn(q, k, @@ -624,8 +631,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - - input_ids = input_ids.view(-1, input_ids.shape[-1]) inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index a3325a7560b08..08ee01394228e 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -514,7 +514,6 @@ def _prepare_encoder_model_input_tensors( encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, - 
dtype=torch.int32, device=self.device) torch.cumsum(encoder_seq_lens_tensor, dim=0, From ffd82c0eaf06757f87d1281a381774301aae0651 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sun, 27 Oct 2024 16:33:15 +0000 Subject: [PATCH 08/32] commits --- vllm/attention/backends/flash_attn.py | 43 +++++++++++++++++++-------- vllm/attention/backends/utils.py | 18 +++++------ 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f5ca33696bf39..83c718f76487a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -798,17 +798,20 @@ def unified_flash_attention( k_scale, v_scale, ) - + + if attn_type == AttentionType.ENCODER: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 elif attn_type == AttentionType.DECODER: # Decoder self-attention supports chunked prefill. num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens # Only enforce this shape-constraint for decoder # self-attention @@ -818,22 +821,33 @@ def unified_flash_attention( # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) num_prefill_tokens = attn_metadata.num_prefill_tokens + if attn_metadata.num_encoder_tokens is not None: + num_encoder_tokens = attn_metadata.num_encoder_tokens + else: + num_encoder_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - - #print('key.shape[0] ' + str(key.shape[0])) - #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - #print('Hell224!!') - #assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - # f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - #print('Hell225!!') # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. 
query = query[:num_prefill_tokens] if (key is not None) and (value is not None): - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + #print('key.shape[0] ' + str(key.shape[0])) + #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + #print('Hell224!!') + #assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + # f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + #print('Hell225!!') + #if attn_type == AttentionType.ENCODER_DECODER: + # key = key + # value = value + #else: + # key = key[:num_prefill_tokens] + # value = value[:num_prefill_tokens] + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] + print('key11.shape() ' + str(key.shape)) + print('value11.shape() ' + str(value.shape)) assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -916,6 +930,11 @@ def unified_flash_attention( _, block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + causal = True + if (attn_type == AttentionType.ENCODER or \ + attn_type == AttentionType.ENCODER_ONLY or \ + attn_type == AttentionType.ENCODER_DECODER) : + causal = False decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), @@ -924,7 +943,7 @@ def unified_flash_attention( block_table=block_tables_arg, cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, - causal=True, + causal=causal, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ).squeeze(1) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 1f62cdcf6dd4e..f799e7e631fed 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -315,7 +315,7 @@ def graph_capture_get_metadata_for_batch( block_tables=self._graph_block_tables[:batch_size], use_cuda_graph=True, ) - #if is_encoder_decoder_model: + if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. # assert self.runner.attn_backend.get_name() == "xformers", \ @@ -324,8 +324,8 @@ def graph_capture_get_metadata_for_batch( # self._update_captured_metadata_for_enc_dec_model( # batch_size=batch_size, attn_metadata=attn_metadata) - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) return attn_metadata @@ -338,14 +338,14 @@ def get_graph_input_buffers( "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - #if is_encoder_decoder_model: + if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. 
#assert self.runner.attn_backend.get_name() == "xformers", \ #f"Expected attn_backend name to be 'xformers', but "\ #f" got '{self.runner.attn_backend.get_name()}'" - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers def prepare_graph_input_buffers( @@ -357,14 +357,14 @@ def prepare_graph_input_buffers( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - #if is_encoder_decoder_model: + if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. #assert self.runner.attn_backend.get_name() == "xformers", \ #f"Expected attn_backend name to be 'xformers', but "\ #f" got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model(attn_metadata, - input_buffers) + self._prepare_input_buffers_for_enc_dec_model(attn_metadata, + input_buffers) def begin_forward(self, model_input) -> None: return From d995221936a23b61f466fa497c233256db0a9143 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 28 Oct 2024 03:33:43 +0000 Subject: [PATCH 09/32] Fixes --- tests/encoder_decoder/test_e2e_correctness.py | 2 +- tests/kernels/test_encoder_decoder_attn.py | 49 +++---- .../vision_language/test_florence2.py | 3 +- vllm/attention/backends/flash_attn.py | 125 ++++++++---------- vllm/attention/backends/utils.py | 36 +++-- vllm/model_executor/models/bart.py | 7 - 6 files changed, 89 insertions(+), 133 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index bef0c515b9073..3f3f31f82eed0 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -29,7 +29,7 @@ def vllm_to_hf_output( @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["bfloat16", "float"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 733006056da46..3bcf86d4f6f97 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -23,12 +23,10 @@ # List of support backends for encoder/decoder models LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] -#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] -#LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.FLASH_ATTN] HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] -#NUM_HEADS = [1] + BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] @@ -132,7 +130,6 @@ class that Attention will automatically select when it is constructed. scale = float(1.0 / (test_pt.head_size**0.5)) attn_backend = make_backend(test_pt.backend_name) - #print('attn_backend ' + str(attn_backend)) attn = Attention( test_pt.num_heads, test_pt.head_size, @@ -640,6 +637,7 @@ def _run_decoder_self_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run decoder self-attention test. 
@@ -668,7 +666,8 @@ def _run_decoder_self_attention_test( packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward(packed_qkv.query, + return attn.forward(#packed_qkv.query, + packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), packed_qkv.key, packed_qkv.value, kv_cache, @@ -681,6 +680,7 @@ def _run_encoder_decoder_cross_attention_test( decoder_test_params: PhaseTestParameters, cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -727,7 +727,9 @@ def _run_encoder_decoder_cross_attention_test( key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) with set_forward_context(attn_metadata): - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, + return attn.forward(#decoder_test_params.packed_qkvo.packed_qkv.query, + decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size), key, value, kv_cache, @@ -795,7 +797,6 @@ def test_encoder_only( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init test_rsrcs = _make_test_resources(test_pt) - print('test_rsrcs.attn_backend ' + str(test_rsrcs.attn_backend)) # Construct encoder attention test params (only used # during prefill) @@ -813,8 +814,6 @@ def test_encoder_only( cross_test_params=None, device=CUDA_DEVICE) - #print('prephase_attn_metadata ' + str(prephase_attn_metadata)) - # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( @@ -823,12 +822,6 @@ def test_encoder_only( # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - -@pytest.fixture(autouse=True) -def print_test_name(request): - print(f"Running test: {request.node.name}") - - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -967,39 +960,33 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder attention - #enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, - # enc_test_params, - # prephase_attn_metadata) + enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is encoder attention result correct? - #assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - - #print('First test succeeded') - #print('prephase_attn_metadata ' + str(prephase_attn_metadata)) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, test_pt=test_pt) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, prephase_dec_pckd_act_out) - print('Second test succeeded') - # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) + prephase_attn_metadata, test_pt=test_pt) # - Is prefill encoder/decoder cross-attention correct? 
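The view(...) calls introduced above exist because the Attention wrapper expects a [num_tokens, hidden_size] query while the test fixtures build a per-head layout; a small sketch, assuming a packed query of shape [num_tokens, num_heads, head_size]:

import torch

# Assumed test-style layout: 8 tokens, 16 heads, head size 64.
num_tokens, num_heads, head_size = 8, 16, 64
packed_query = torch.randn(num_tokens, num_heads, head_size)

# Collapse the head dimensions so the wrapper sees [num_tokens, hidden_size].
reshaped_query = packed_query.view(-1, num_heads * head_size)
assert reshaped_query.shape == (num_tokens, num_heads * head_size)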
assert_actual_matches_ideal(prephase_cross_test_params, prephase_cross_pckd_act_out) - print('Third test succeeded') - # DECODE: build decode-phase attention metadata decphase_attn_metadata: AttentionMetadata = make_test_metadata( @@ -1014,19 +1001,17 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, test_pt=test_pt) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, decphase_dec_pckd_act_out) - print('Fourth test succeeded') # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata, test_pt=test_pt) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, decphase_cross_pckd_act_out) - print('Fifth test succeeded') diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 483773f069133..069fc287301e9 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -85,7 +85,8 @@ def run_test( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +#@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8e733b4f2e5fb..b2ff0952bd7f7 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -532,8 +532,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - print('build.seq_start_loc ' + str(seq_start_loc)) - return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -677,7 +675,7 @@ def forward( return output -def _get_query_key_seq_metadata( +def _get_key_query_seq_metadata( attn_metadata, is_prompt: bool, attn_type: AttentionType, @@ -714,6 +712,43 @@ def _get_query_key_seq_metadata( else: raise AttributeError(f"Invalid attention type {str(attn_type)}") +def _get_num_prefill_encode_decode_tokens( + attn_metadata: FlashAttentionMetadata, + attn_type: AttentionType, +) -> tuple[int, int, int]: + if attn_type == AttentionType.ENCODER: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = attn_metadata.num_decode_tokens 
+ else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens) + +def _get_causal_option(attn_type: AttentionType)-> bool: + if (attn_type == AttentionType.ENCODER or \ + attn_type == AttentionType.ENCODER_ONLY or \ + attn_type == AttentionType.ENCODER_DECODER) : + return False + + return True + + + @torch.library.custom_op("vllm::unified_flash_attention", mutates_args=["kv_cache"]) @@ -747,19 +782,10 @@ def unified_flash_attention( assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata - print('query.shape ' + str(query.shape)) - if key is not None: - print('key.shape ' + str(key.shape)) - print('value.shape ' + str(key.shape)) - print('attn_type ' + str(attn_type)) - print('num_heads ' + str(num_heads)) - print('num_kv_heads ' + str(num_kv_heads)) - #num_tokens, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) - hidden_size = num_heads * head_size - num_tokens = query.shape[0] if (key is not None) and (value is not None): key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) @@ -793,56 +819,11 @@ def unified_flash_attention( v_scale, ) - - if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.DECODER: - # Decoder self-attention supports chunked prefill. - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - else: # attn_type == AttentionType.ENCODER_DECODER - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Query for decode. KV is not needed because it is already cached. + num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ + _get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) decode_query = query[num_prefill_tokens:] # QKV for prefill. 
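The _get_causal_option helper added above gates the causal mask on the attention type; a self-contained sketch of the same decision, using stand-in enum values rather than vLLM's AttentionType:

from enum import Enum, auto

class AttnType(Enum):
    # Stand-ins for vLLM's AttentionType members, for illustration only.
    DECODER = auto()
    ENCODER = auto()
    ENCODER_ONLY = auto()
    ENCODER_DECODER = auto()

def is_causal(attn_type: AttnType) -> bool:
    # Only decoder self-attention masks future positions; encoder attention,
    # encoder-only attention and cross-attention attend to the full sequence.
    return attn_type == AttnType.DECODER

assert is_causal(AttnType.DECODER)
assert not is_causal(AttnType.ENCODER_DECODER)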
query = query[:num_prefill_tokens] - if (key is not None) and (value is not None): - #print('key.shape[0] ' + str(key.shape[0])) - #assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - # f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - #print('Hell224!!') - #assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - # f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - #print('Hell225!!') - #if attn_type == AttentionType.ENCODER_DECODER: - # key = key - # value = value - #else: - # key = key[:num_prefill_tokens] - # value = value[:num_prefill_tokens] - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] - print('key11.shape() ' + str(key.shape)) - print('value11.shape() ' + str(value.shape)) - assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -855,13 +836,17 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_query_key_seq_metadata( + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_key_query_seq_metadata( prefill_meta, True, attn_type) - causal = True + if (attn_type == AttentionType.ENCODER or \ - attn_type == AttentionType.ENCODER_ONLY or \ - attn_type == AttentionType.ENCODER_DECODER) : - causal = False + attn_type == AttentionType.ENCODER_DECODER): + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] + else: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + prefill_output = flash_attn_varlen_func( q=query, k=key, @@ -871,7 +856,7 @@ def unified_flash_attention( max_seqlen_q=q_seq_len, max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, - causal=causal, + causal=_get_causal_option(attn_type), window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, @@ -926,12 +911,6 @@ def unified_flash_attention( _, block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - causal = True - if (attn_type == AttentionType.ENCODER or \ - attn_type == AttentionType.ENCODER_ONLY or \ - attn_type == AttentionType.ENCODER_DECODER) : - causal = False - decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, @@ -939,7 +918,7 @@ def unified_flash_attention( block_table=block_tables_arg, cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, - causal=causal, + causal=True, window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index f799e7e631fed..723d42954936c 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -316,14 +316,11 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, ) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - # assert self.runner.attn_backend.get_name() == "xformers", \ - # f"Expected attn_backend name to be 'xformers', but "\ - # f" got '{self.runner.attn_backend.get_name()}'" - # self._update_captured_metadata_for_enc_dec_model( - # batch_size=batch_size, attn_metadata=attn_metadata) - + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. 
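The branch above slices key/value by the encoder token count for encoder and cross-attention because in those cases the key/value length is the encoder sequence length, which need not match the decoder-side query length; a toy illustration with assumed sizes:

import torch

# Assumed toy sizes: 5 decoder prefill tokens attending over 9 encoder tokens.
num_prefill_tokens, num_encoder_tokens, num_heads, head_size = 5, 9, 2, 4

query = torch.randn(num_prefill_tokens, num_heads, head_size)   # decoder side
key = torch.randn(num_encoder_tokens, num_heads, head_size)     # encoder side
value = torch.randn(num_encoder_tokens, num_heads, head_size)

# Cross-attention legitimately has query and key/value of different lengths,
# which is why the decoder-only shape check is not applied here.
assert key.shape[0] == num_encoder_tokens
assert query.shape[0] != key.shape[0]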
+ assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -339,11 +336,11 @@ def get_graph_input_buffers( "block_tables": attn_metadata.decode_metadata.block_tables, } if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - #assert self.runner.attn_backend.get_name() == "xformers", \ - #f"Expected attn_backend name to be 'xformers', but "\ - #f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers @@ -358,13 +355,13 @@ def prepare_graph_input_buffers( input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - #assert self.runner.attn_backend.get_name() == "xformers", \ - #f"Expected attn_backend name to be 'xformers', but "\ - #f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._prepare_input_buffers_for_enc_dec_model(attn_metadata, - input_buffers) + input_buffers) def begin_forward(self, model_input) -> None: return @@ -394,6 +391,7 @@ def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, attn_metadata.encoder_seq_lens_tensor = torch.full( (batch_size, ), 1, dtype=torch.int).cuda() attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 def _add_additonal_input_buffers_for_enc_dec_model( self, attn_metadata, input_buffers: Dict[str, Any]): diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index d115ebb1c9249..0543ca978b7dd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -184,9 +184,6 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - print('Enc qkv_shape ' + str(qkv.shape)) - print('Enc q_shape ' + str(q.shape)) - print('Enc hidden_states_shape ' + str(hidden_states.shape)) attn_output = self.attn(q, k, @@ -269,10 +266,6 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - print('Dec qkv_shape ' + str(qkv.shape)) - print('Dec q_shape ' + str(q.shape)) - print('Dec hidden_states_shape ' + str(hidden_states.shape)) - attn_output = self.attn(q, k, From d1f814018f084a74359307004b88266a0ac66ff4 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 28 Oct 2024 20:28:52 +0000 Subject: 
[PATCH 10/32] Format --- tests/encoder_decoder/test_e2e_correctness.py | 8 ++ tests/kernels/test_encoder_decoder_attn.py | 76 +++++++++------ tests/kernels/utils.py | 8 +- .../encoder_decoder/language/test_bart.py | 8 ++ vllm/attention/backends/flash_attn.py | 95 ++++++++++++++----- vllm/attention/backends/utils.py | 30 +++--- vllm/attention/backends/xformers.py | 5 +- vllm/attention/selector.py | 5 +- 8 files changed, 163 insertions(+), 72 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 3f3f31f82eed0..df47cf152ad50 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,6 +7,7 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import get_attn_backend from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs @@ -28,6 +29,13 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear the cached value of attention backend before each test.""" + get_attn_backend.cache_clear() + yield + + @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("dtype", ["bfloat16", "float"]) @pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3bcf86d4f6f97..848d8f5986f54 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -11,14 +11,14 @@ import pytest import torch -from vllm.forward_context import get_forward_context, set_forward_context from tests.kernels.utils import * from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, +from vllm.attention.selector import (_Backend, get_attn_backend, global_force_attn_backend_context_manager) +from vllm.forward_context import set_forward_context from vllm.utils import is_hip # List of support backends for encoder/decoder models @@ -27,7 +27,6 @@ NUM_HEADS = [1, 16] - BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] CUDA_DEVICE = "cuda:0" @@ -623,7 +622,8 @@ def _run_encoder_attention_test( packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward(packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), + return attn.forward(packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size), packed_qkv.key, packed_qkv.value, torch.tensor([], @@ -666,13 +666,13 @@ def _run_decoder_self_attention_test( packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward(#packed_qkv.query, - packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward( #packed_qkv.query, + packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -727,14 +727,21 @@ def _run_encoder_decoder_cross_attention_test( key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) with 
set_forward_context(attn_metadata): - return attn.forward(#decoder_test_params.packed_qkvo.packed_qkv.query, - decoder_test_params.packed_qkvo.packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size), - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward( #decoder_test_params.packed_qkvo.packed_qkv.query, + decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size), + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear the cached value of attention backend before each test.""" + get_attn_backend.cache_clear() + yield @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -782,7 +789,6 @@ def test_encoder_only( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): torch.set_default_dtype(torch.bfloat16) @@ -817,11 +823,15 @@ def test_encoder_only( # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata, test_pt=test_pt)) + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=test_pt)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -899,7 +909,6 @@ def test_e2e_enc_dec_attn( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): torch.set_default_dtype(torch.bfloat16) @@ -971,7 +980,10 @@ def test_e2e_enc_dec_attn( # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, test_pt=test_pt) + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, @@ -980,8 +992,11 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata, test_pt=test_pt) + test_rsrcs, + prephase_dec_test_params, + prephase_cross_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, @@ -1001,7 +1016,10 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, test_pt=test_pt) + test_rsrcs, + decphase_dec_test_params, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase decoder self-attention correct? 
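The clear_cache fixture above is needed because the backend selector is memoized; once a backend has been selected and cached, a later test that forces a different backend would otherwise reuse the first choice. A minimal sketch of that interaction, using a stand-in selector rather than vLLM's get_attn_backend:

from functools import lru_cache

FORCED_BACKEND = "XFORMERS"   # stand-in for a module-level override

@lru_cache(maxsize=None)
def select_backend(head_size: int) -> str:
    # Stand-in selector: the override is not part of the cache key, so a
    # cached result can outlive a change to FORCED_BACKEND.
    return FORCED_BACKEND

print(select_backend(64))     # "XFORMERS" is cached for head_size=64
FORCED_BACKEND = "FLASH_ATTN"
print(select_backend(64))     # still "XFORMERS": the stale cache entry is reused
select_backend.cache_clear()  # what the autouse fixture does between tests
print(select_backend(64))     # now "FLASH_ATTN"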
assert_actual_matches_ideal(decphase_dec_test_params, @@ -1010,7 +1028,11 @@ def test_e2e_enc_dec_attn( # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata, test_pt=test_pt) + test_rsrcs, + decphase_dec_test_params, + None, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 127b980480fbb..909ef82b82194 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,8 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, - STR_FLASH_ATTN_VAL, make_tensor_with_pad) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -631,8 +631,8 @@ def make_kv_cache(num_blocks: int, (2, num_blocks, block_size, num_heads, head_size)).to(device) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." - ) + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 8e8862fadbf04..1113a66f42f93 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -7,6 +7,7 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import get_attn_backend from vllm.sequence import SampleLogprobs from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, @@ -170,6 +171,13 @@ def run_test( ) +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear the cached value of attention backend before each test.""" + get_attn_backend.cache_clear() + yield + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index b2ff0952bd7f7..81a5cf57defa3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -12,8 +12,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, - is_block_tables_empty, - get_seq_len_block_table_args) + get_seq_len_block_table_args, + is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -675,11 +675,33 @@ def forward( return output -def _get_key_query_seq_metadata( +def _get_query_key_seq_metadata( attn_metadata, is_prompt: bool, attn_type: AttentionType, ) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. 
+ + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ if attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run @@ -712,14 +734,30 @@ def _get_key_query_seq_metadata( else: raise AttributeError(f"Invalid attention type {str(attn_type)}") + def _get_num_prefill_encode_decode_tokens( attn_metadata: FlashAttentionMetadata, attn_type: AttentionType, -) -> tuple[int, int, int]: +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill, encoder, and decode tokens based on the + attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill tokens. + - The number of encoder tokens. + - The number of decode tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ if attn_type == AttentionType.ENCODER: # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens @@ -731,23 +769,31 @@ def _get_num_prefill_encode_decode_tokens( num_prefill_tokens = attn_metadata.num_prefill_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = attn_metadata.num_decode_tokens - else: # attn_type == AttentionType.DECODER or - # attn_type == AttentionType.ENCODER_ONLY + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = 0 num_decode_tokens = attn_metadata.num_decode_tokens return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens) -def _get_causal_option(attn_type: AttentionType)-> bool: - if (attn_type == AttentionType.ENCODER or \ - attn_type == AttentionType.ENCODER_ONLY or \ - attn_type == AttentionType.ENCODER_DECODER) : - return False - - return True - +def _get_causal_option(attn_type: AttentionType) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. 
+ """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) @torch.library.custom_op("vllm::unified_flash_attention", @@ -773,9 +819,9 @@ def unified_flash_attention( # Convert integer attn_type to enum try: attn_type = AttentionType(attn_type_int_val) - except ValueError: + except ValueError as err: raise AttributeError( - f"Invalid attention type {str(attn_type_int_val)}") + f"Invalid attention type {str(attn_type_int_val)}") from err current_metadata = get_forward_context() assert current_metadata is not None @@ -813,12 +859,13 @@ def unified_flash_attention( value, kv_cache[0], kv_cache[1], - updated_slot_mapping.flatten(), + updated_slot_mapping.flatten() + if updated_slot_mapping is not None else None, kv_cache_dtype, k_scale, v_scale, ) - + num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ _get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) decode_query = query[num_prefill_tokens:] @@ -836,8 +883,8 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = _get_key_query_seq_metadata( - prefill_meta, True, attn_type) + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) if (attn_type == AttentionType.ENCODER or \ attn_type == AttentionType.ENCODER_DECODER): @@ -846,7 +893,7 @@ def unified_flash_attention( else: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] - + prefill_output = flash_attn_varlen_func( q=query, k=key, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 723d42954936c..c103e03ff5598 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -7,8 +7,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) -from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.attention.backends.abstract import AttentionType +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase @@ -318,9 +318,11 @@ def graph_capture_get_metadata_for_batch( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -338,9 +340,11 @@ def get_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. 
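flash_attn_varlen_func consumes cumulative start offsets (cu_seqlens) plus a max length for the query and key sides; a small sketch of how such offsets are derived from per-request lengths, mirroring the torch.cumsum pattern used by the metadata builders (lengths assumed for illustration):

import torch

# Assumed per-request sequence lengths in a flattened (varlen) batch.
seq_lens = torch.tensor([4, 6, 3], dtype=torch.int32)

# batch_size + 1 cumulative offsets: [0, 4, 10, 13].
seq_start_loc = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=seq_start_loc[1:])

max_seq_len = int(seq_lens.max())
print(seq_start_loc.tolist(), max_seq_len)   # [0, 4, 10, 13] 6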
- assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers @@ -357,11 +361,13 @@ def prepare_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or 'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model(attn_metadata, - input_buffers) + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 91f9434c91d99..59e6e0015bbdf 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -13,9 +13,9 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, CommonMetadataBuilder, - is_all_encoder_attn_metadata_set, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - get_seq_len_block_table_args) + is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -360,6 +360,7 @@ def _get_seq_len_block_table_args( ''' return get_seq_len_block_table_args(attn_metadata, is_prompt, attn_type) + class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): _metadata_cls = XFormersMetadata diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d564ec8c26aba..38510237a2014 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -88,7 +88,7 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -#@lru_cache(maxsize=None) +@lru_cache(maxsize=None) def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -98,17 +98,16 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - print('In get_attn_backend1') if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - print('In get_attn_backend2') backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: + logger.info("Using Flash Attention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend From d99370ceaac4e847ea88f6556919ec171602a056 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 28 Oct 2024 20:34:56 +0000 Subject: [PATCH 11/32] Remove unused import --- 
tests/kernels/test_encoder_decoder_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 02d7d78a8cd47..05087bc3643b4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -19,7 +19,6 @@ from vllm.attention.selector import (_Backend, get_attn_backend, global_force_attn_backend_context_manager) from vllm.forward_context import set_forward_context -from vllm.utils import is_hip from vllm.platforms import current_platform # List of support backends for encoder/decoder models From 11bda4f4bebcd3aba85a444a72220fcdb8e2613d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 28 Oct 2024 21:20:20 +0000 Subject: [PATCH 12/32] Reverting layer changes --- vllm/attention/layer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 58d6dadbda110..33d05cbd3fe01 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,11 +78,9 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() - print('dtype ' + str(dtype)) attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, block_size, is_attention_free, blocksparse_params is not None) - print('attn_backend ' + str(attn_backend)) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, From 040f61e903648716f92b0183a1e1e300f0ac80a5 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 28 Oct 2024 23:37:02 +0000 Subject: [PATCH 13/32] Fixes --- vllm/attention/selector.py | 4 ---- vllm/utils.py | 4 ++-- vllm/worker/enc_dec_model_runner.py | 25 ++++++++++++++++--------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index a8d812447ed0e..8a59cf41a689e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -180,7 +180,6 @@ def which_attn_to_use( # ENVIRONMENT VARIABLE. backend_by_global_setting: Optional[_Backend] = ( get_global_forced_attn_backend()) - print('backend_by_global_setting ' + str(backend_by_global_setting)) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -299,9 +298,6 @@ def global_force_attn_backend_context_manager( # Globally force the new backend override global_force_attn_backend(attn_backend) - print('original value ' + str(original_value)) - print('new value ' + str(attn_backend)) - # Yield control back to the enclosed code block try: yield diff --git a/vllm/utils.py b/vllm/utils.py index c3f9a6bdd8b80..1a2d8f5a3bd4e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -80,8 +80,8 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " - "currently supported with encoder/" +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " + "backends currently supported with encoder/" "decoder models.") STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 08ee01394228e..374392ea2f676 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -101,8 +101,8 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. 
''' - - #self._maybe_force_supported_attention_backend() + print('model_name ' + str(model_config.model)) + self._maybe_force_supported_attention_backend(model_config.model) super().__init__( model_config, @@ -119,7 +119,10 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _maybe_force_supported_attention_backend(self): + def _is_xformers_only_encoder_decoder_model(self, model: str) -> bool: + return "llama-3.2-11b-vision-instruct" in model.lower() + + def _maybe_force_supported_attention_backend(self, model: str): ''' Force vLLM to use the XFormers attention backend, which is currently the only supported option. @@ -135,22 +138,26 @@ def raise_backend_err(): is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None - if not (is_forced_by_global or is_forced_by_env_var): + if not (is_forced_by_global or is_forced_by_env_var) \ + and self._is_xformers_only_encoder_decoder_model(model): # The user has not already specified an attention backend # override - logger.info("EncoderDecoderModelRunner requires " - "XFormers backend; overriding backend " - "auto-selection and forcing XFormers.") + logger.info( + "Encoder-Decoder Model %s requires XFormers backend; " + "overriding backend auto-selection and " + "forcing XFormers.", model) global_force_attn_backend(_Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. - if maybe_global_forced_backend != _Backend.XFORMERS: + if maybe_global_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() elif is_forced_by_env_var: # Backend override enforced by vLLM backend # environment variable - if maybe_env_var_forced_backend != _Backend.XFORMERS: + if maybe_env_var_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() def _list_to_int32_tensor( From 1bc6fe18a3952bdd90958da02e09e2b157e529d0 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 01:48:31 +0000 Subject: [PATCH 14/32] Fixes --- tests/encoder_decoder/test_e2e_correctness.py | 8 -------- tests/kernels/test_encoder_decoder_attn.py | 19 ++++++++++++------- tests/kernels/utils.py | 10 ---------- .../encoder_decoder/language/test_bart.py | 8 -------- vllm/attention/backends/flash_attn.py | 3 ++- vllm/worker/enc_dec_model_runner.py | 2 +- 6 files changed, 15 insertions(+), 35 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index df47cf152ad50..855b96bcede95 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -28,14 +28,6 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear the cached value of attention backend before each test.""" - get_attn_backend.cache_clear() - yield - - @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("dtype", ["bfloat16", "float"]) @pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 05087bc3643b4..e091570916c80 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -622,8 +622,10 @@ def _run_encoder_attention_test( packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert 
packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward(packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size), + # TODO - Fix the shape of the query to be [] + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, torch.tensor([], @@ -666,8 +668,11 @@ def _run_decoder_self_attention_test( packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - return attn.forward( #packed_qkv.query, - packed_qkv.query.view(-1, test_pt.num_heads * test_pt.head_size), + # The current test assumes that the input query is of the + # shape + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, kv_cache, @@ -727,9 +732,9 @@ def _run_encoder_decoder_cross_attention_test( key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) with set_forward_context(attn_metadata): - return attn.forward( #decoder_test_params.packed_qkvo.packed_qkv.query, - decoder_test_params.packed_qkvo.packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size), + reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, key, value, kv_cache, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 909ef82b82194..a8d6ecf9ff754 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -527,7 +527,6 @@ def make_backend(backend_name: str) -> AttentionBackend: from vllm.attention.backends.xformers import XFormersBackend return XFormersBackend() elif backend_name == STR_FLASH_ATTN_VAL: - #print('Hello') from vllm.attention.backends.flash_attn import FlashAttentionBackend return FlashAttentionBackend() @@ -581,10 +580,6 @@ def _make_metadata_tensors( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - print('seq_start_loc ' + str(seq_start_loc)) - print('seq_lens_tensor ' + str(seq_lens_tensor)) - print('max_seq_len ' + str(max_seq_len)) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=encoder_seq_lens_tensor.device) @@ -593,8 +588,6 @@ def _make_metadata_tensors( dtype=encoder_seq_start_loc.dtype, out=encoder_seq_start_loc[1:]) - #print('encoder_seq_start_loc ' + str(encoder_seq_start_loc)) - return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len) @@ -839,7 +832,6 @@ def make_test_metadata( * AttentionMetadata structure ''' - #print('Here for metadata!!!') # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None @@ -885,8 +877,6 @@ def make_test_metadata( # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap - #print('Here for metadata!!') - if is_prompt: # Prefill-phase scenario diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 1113a66f42f93..a2eea43aa1357 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -170,14 +170,6 @@ def run_test( num_outputs_0_skip_tokens=hf_skip_tokens, ) - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear the cached value of attention 
backend before each test.""" - get_attn_backend.cache_clear() - yield - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 81a5cf57defa3..36c09ff15633e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -813,7 +813,7 @@ def unified_flash_attention( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, - attn_type_int_val: int = 0, + attn_type_int_val: int = AttentionType.DECODER.value, ) -> torch.Tensor: # Convert integer attn_type to enum @@ -1002,5 +1002,6 @@ def _( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, + attn_type_int_val: int = AttentionType.DECODER.value, ) -> torch.Tensor: return torch.empty_like(query) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 374392ea2f676..1bc06abe2a631 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -101,7 +101,6 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. ''' - print('model_name ' + str(model_config.model)) self._maybe_force_supported_attention_backend(model_config.model) super().__init__( @@ -120,6 +119,7 @@ def __init__( assert_enc_dec_mr_supported_scenario(self) def _is_xformers_only_encoder_decoder_model(self, model: str) -> bool: + # The Llama 3.2 model implementation uses return "llama-3.2-11b-vision-instruct" in model.lower() def _maybe_force_supported_attention_backend(self, model: str): From 4573f8f26cd2f3aac74f2df81e9fc4907a883fbb Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 01:50:28 +0000 Subject: [PATCH 15/32] Format --- tests/encoder_decoder/test_e2e_correctness.py | 2 +- tests/kernels/test_encoder_decoder_attn.py | 26 +++++++++---------- .../encoder_decoder/language/test_bart.py | 2 +- vllm/worker/enc_dec_model_runner.py | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 855b96bcede95..3f3f31f82eed0 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,7 +7,6 @@ import pytest from transformers import AutoModelForSeq2SeqLM -from vllm.attention.selector import get_attn_backend from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs @@ -28,6 +27,7 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs + @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("dtype", ["bfloat16", "float"]) @pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e091570916c80..8e3ceb0fc675f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -668,16 +668,16 @@ def _run_decoder_self_attention_test( packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - # The current test assumes that the input query is of the - # shape + # The current test assumes that the input query is of the + # shape reshaped_query = 
packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) + -1, test_pt.num_heads * test_pt.head_size) return attn.forward(reshaped_query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -735,11 +735,11 @@ def _run_encoder_decoder_cross_attention_test( reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) return attn.forward(reshaped_query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) @pytest.fixture(autouse=True) diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index a2eea43aa1357..8e8862fadbf04 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -7,7 +7,6 @@ import pytest from transformers import AutoModelForSeq2SeqLM -from vllm.attention.selector import get_attn_backend from vllm.sequence import SampleLogprobs from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, @@ -170,6 +169,7 @@ def run_test( num_outputs_0_skip_tokens=hf_skip_tokens, ) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 1bc06abe2a631..b8f98e4185724 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -119,7 +119,7 @@ def __init__( assert_enc_dec_mr_supported_scenario(self) def _is_xformers_only_encoder_decoder_model(self, model: str) -> bool: - # The Llama 3.2 model implementation uses + # The Llama 3.2 model implementation uses return "llama-3.2-11b-vision-instruct" in model.lower() def _maybe_force_supported_attention_backend(self, model: str): From ed587cb26397f9e8890da0abca85be7ba178583d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 02:33:07 +0000 Subject: [PATCH 16/32] Fix test reset logic --- tests/kernels/test_encoder_decoder_attn.py | 38 +++++++++++++++++----- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 8e3ceb0fc675f..2f979e7adfa2f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -622,7 +622,13 @@ def _run_encoder_attention_test( packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - # TODO - Fix the shape of the query to be [] + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) return attn.forward(reshaped_query, @@ -668,8 +674,13 @@ def _run_decoder_self_attention_test( packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata): - # The current test assumes that the input query is of the - # shape + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) return attn.forward(reshaped_query, @@ -732,6 +743,13 @@ def _run_encoder_decoder_cross_attention_test( key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) return attn.forward(reshaped_query, @@ -743,10 +761,17 @@ def _run_encoder_decoder_cross_attention_test( @pytest.fixture(autouse=True) -def clear_cache(): - """Clear the cached value of attention backend before each test.""" +def set_reset_environment(): + # Set the default torch datatype to bfloat16 to enable + # testing of the Flash Attention backend. Also clear the + # cached value of the backend. + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) get_attn_backend.cache_clear() yield + # Reset the torch datatype to what it was before the test + # so as not to impact the remaining tests. 
+ torch.set_default_dtype(default_dtype) @pytest.mark.skipif(current_platform.is_rocm(), @@ -798,8 +823,6 @@ def test_encoder_only( ''' # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - torch.set_default_dtype(torch.bfloat16) - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -920,7 +943,6 @@ def test_e2e_enc_dec_attn( ''' # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - torch.set_default_dtype(torch.bfloat16) # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test From 18e8a9778e88482ae013c6edb90e6cdd1e636902 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 05:31:13 +0000 Subject: [PATCH 17/32] Fixes --- tests/kernels/test_encoder_decoder_attn.py | 23 +++++++++------ tests/kernels/utils.py | 28 +++++++++++++++---- .../vision_language/test_florence2.py | 1 - vllm/attention/backends/flash_attn.py | 28 +++++++++++-------- 4 files changed, 54 insertions(+), 26 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2f979e7adfa2f..2eebf16961832 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -761,12 +761,13 @@ def _run_encoder_decoder_cross_attention_test( @pytest.fixture(autouse=True) -def set_reset_environment(): +def set_reset_environment(attn_backend): # Set the default torch datatype to bfloat16 to enable # testing of the Flash Attention backend. Also clear the # cached value of the backend. default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.bfloat16) + if attn_backend.name == 'FLASH_ATTN': + torch.set_default_dtype(torch.bfloat16) get_attn_backend.cache_clear() yield # Reset the torch datatype to what it was before the test @@ -859,7 +860,8 @@ def test_encoder_only( test_pt=test_pt)) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) @pytest.mark.skipif(current_platform.is_rocm(), @@ -1006,7 +1008,8 @@ def test_e2e_enc_dec_attn( test_pt=test_pt) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) # PREFILL: decoder self-attention test @@ -1018,7 +1021,8 @@ def test_e2e_enc_dec_attn( # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out) + prephase_dec_pckd_act_out, + attn_backend.name) # PREFILL: encoder/decoder cross-attention test @@ -1031,7 +1035,8 @@ def test_e2e_enc_dec_attn( # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + prephase_cross_pckd_act_out, + attn_backend.name) # DECODE: build decode-phase attention metadata @@ -1054,7 +1059,8 @@ def test_e2e_enc_dec_attn( # - Is decode-phase decoder self-attention correct? 
assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + decphase_dec_pckd_act_out, + attn_backend.name) # DECODE: encoder/decoder cross-attention test @@ -1067,4 +1073,5 @@ def test_e2e_enc_dec_attn( # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + decphase_cross_pckd_act_out, + attn_backend.name) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index a8d6ecf9ff754..a218a175a9948 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -558,6 +558,8 @@ def _make_metadata_tensors( * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence + * encoder_seq_lens_tensor: encoder seq_lens list, as tensor + * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) @@ -615,6 +617,9 @@ def make_kv_cache(num_blocks: int, Returns: * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + * for backend 'XFORMERS' + * kv_cache: 2 x num_blocks x block_size x num_heads x head_size + * for backend 'FLASH_ATTN' ''' if backend == 'XFORMERS': kv_cache = torch.rand( @@ -972,7 +977,8 @@ def make_test_metadata( def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: + output_under_test: torch.Tensor, + backend: str) -> None: ''' Assert that observed output matches the ideal output contained in the test parameters data structure. @@ -983,10 +989,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * output_under_test: actually observed output value ''' ideal_output = test_params.packed_qkvo.ideal_output - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output), - atol=0.01, - rtol=0.016) + if backend == 'XFORMERS': + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output)) + + elif backend == 'FLASH_ATTN': + # For FlashAttention override the accuracy thresholds to non default + # values since we notice a higher difference between the ideal and + # actual output. + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output), + atol=0.01, + rtol=0.016) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. 
Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") # Copied/modified from torch._refs.__init__.py diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 069fc287301e9..d686f1da3fa17 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -86,7 +86,6 @@ def run_test( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -#@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 36c09ff15633e..3eeb54c062650 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,6 +13,8 @@ compute_slot_mapping, compute_slot_mapping_start_idx, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -174,9 +176,7 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return is_all_encoder_attn_metadata_set(self) @property def is_all_cross_attn_metadata_set(self): @@ -185,9 +185,7 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return is_all_cross_attn_metadata_set(self) @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: @@ -713,8 +711,11 @@ def _get_query_key_seq_metadata( attn_metadata.seq_start_loc, max_seq_len) elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. if is_prompt: max_seq_len = attn_metadata.max_prefill_seq_len else: @@ -723,6 +724,8 @@ def _get_query_key_seq_metadata( attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len) elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. return (attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, attn_metadata.encoder_seq_start_loc, @@ -757,20 +760,20 @@ def _get_num_prefill_encode_decode_tokens( is `None` when required for the calculations. """ if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; + # Encoder attention is only invoked during prefill phase. 
assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 elif attn_type == AttentionType.ENCODER_DECODER: - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_prefill_tokens num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = attn_metadata.num_decode_tokens else: # attn_type == AttentionType.DECODER or # attn_type == AttentionType.ENCODER_ONLY + # There are no encoder tokens for DECODER and ENCODER_ONLY + # attention type. num_prefill_tokens = attn_metadata.num_prefill_tokens num_encoder_tokens = 0 num_decode_tokens = attn_metadata.num_decode_tokens @@ -839,7 +842,6 @@ def unified_flash_attention( if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( value is not None): if attn_type == AttentionType.ENCODER_DECODER: @@ -936,6 +938,8 @@ def unified_flash_attention( # because different queries might have different lengths. assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Decoder only models support max_decode_query_len > 1") decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, From 9f7dc043b841bf7e82e9ef0c594d86cb52885fa2 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 05:41:45 +0000 Subject: [PATCH 18/32] Dummu --- vllm/attention/backends/flash_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3eeb54c062650..8060e79591f6a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -85,6 +85,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) From fce7f621d918a827cb512b29e201139ce1782818 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 05:43:14 +0000 Subject: [PATCH 19/32] Fix --- vllm/attention/backends/flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8060e79591f6a..971fe6d65f97e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -73,7 +73,6 @@ def swap_blocks( src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) From d596c236932583a92c22ca5e4b90f1dc9b9972b0 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 05:58:34 +0000 Subject: [PATCH 20/32] Dummy --- vllm/attention/backends/flash_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 971fe6d65f97e..53fa1551f19cf 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -127,11 +127,13 @@ class FlashAttentionMetadata(AttentionMetadata): # in the kv cache. Each block can contain up to block_size tokens. 
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. + block_tables: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: Optional[bool] # Maximum query length in the batch. From 7bed5e6b660103a375bcfabce1a757c7d766425d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 06:06:54 +0000 Subject: [PATCH 21/32] Format --- vllm/attention/backends/flash_attn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 53fa1551f19cf..edeb4ffc5304f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,13 +9,11 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, - is_block_tables_empty) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, + is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -127,7 +125,6 @@ class FlashAttentionMetadata(AttentionMetadata): # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. - block_tables: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. 
From 3a7d05e95f69ebf6c9a4bc2f5c3b1b32c3dcb8fa Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 29 Oct 2024 06:11:50 +0000 Subject: [PATCH 22/32] Format --- vllm/attention/backends/flash_attn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index edeb4ffc5304f..ed3b87f8fdcfc 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,14 +9,15 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, - is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from .utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, + get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) + if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) From 0604c0a30932d3938a59c35fa5f3c49afba5edbd Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 30 Oct 2024 06:19:22 +0000 Subject: [PATCH 23/32] Comments --- .../test_encoder_decoder_model_runner.py | 1 - vllm/attention/backends/flash_attn.py | 58 +++---------------- vllm/attention/backends/utils.py | 45 +++++++++++++- vllm/attention/backends/xformers.py | 39 ++++--------- vllm/worker/enc_dec_model_runner.py | 19 +++--- 5 files changed, 72 insertions(+), 90 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index e75884a7395e2..3127f8a290813 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -13,7 +13,6 @@ BATCH_SIZES = [1, 4, 16, 64, 256] - def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ed3b87f8fdcfc..fa92746e5b03a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,6 +15,7 @@ from .utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, get_seq_len_block_table_args, + get_num_prefill_encode_decode_tokens, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) @@ -132,7 +133,7 @@ class FlashAttentionMetadata(AttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: Optional[bool] + use_cuda_graph: bool # Maximum query length in the batch. 
max_query_len: Optional[int] = None @@ -664,10 +665,10 @@ def forward( k_scale, v_scale, self.scale, + attn_type.value, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, - attn_type.value, ) return output @@ -738,49 +739,6 @@ def _get_query_key_seq_metadata( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_num_prefill_encode_decode_tokens( - attn_metadata: FlashAttentionMetadata, - attn_type: AttentionType, -) -> Tuple[int, int, int]: - """ - Calculate the number of prefill, encoder, and decode tokens based on the - attention metadata and the specified attention type. - - Args: - attn_metadata (FlashAttentionMetadata): Attention Metadata object. - attn_type (AttentionType): The type of attention being used. - Returns: - Tuple[int, int, int]: A tuple containing three integers: - - The number of prefill tokens. - - The number of encoder tokens. - - The number of decode tokens. - - Raises: - AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. - """ - if attn_type == AttentionType.ENCODER: - # Encoder attention is only invoked during prefill phase. - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - else: # attn_type == AttentionType.DECODER or - # attn_type == AttentionType.ENCODER_ONLY - # There are no encoder tokens for DECODER and ENCODER_ONLY - # attention type. - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens) - - def _get_causal_option(attn_type: AttentionType) -> bool: """ Determine whether the given attention type is suitable for causal @@ -813,10 +771,10 @@ def unified_flash_attention( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, - attn_type_int_val: int = AttentionType.DECODER.value, ) -> torch.Tensor: # Convert integer attn_type to enum @@ -869,7 +827,7 @@ def unified_flash_attention( ) num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ - _get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) + get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) decode_query = query[num_prefill_tokens:] # QKV for prefill. 
query = query[:num_prefill_tokens] @@ -913,7 +871,7 @@ def unified_flash_attention( else: # prefix-enabled attention assert attn_type == AttentionType.DECODER, ( - "Decoder only models currently support prefix caching") + "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) prefill_output = flash_attn_varlen_func( # noqa @@ -939,7 +897,7 @@ def unified_flash_attention( assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: assert attn_type == AttentionType.DECODER, ( - "Decoder only models support max_decode_query_len > 1") + "Only decoder-only models support max_decode_query_len > 1") decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, @@ -1003,9 +961,9 @@ def _( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, - attn_type_int_val: int = AttentionType.DECODER.value, ) -> torch.Tensor: return torch.empty_like(query) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index c103e03ff5598..41de2a825995b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,6 +1,6 @@ """Attention backend utils""" from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -512,3 +512,46 @@ def get_seq_len_block_table_args( attn_metadata.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_encode_decode_tokens( + attn_metadata, + attn_type: AttentionType, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill, encoder, and decode tokens based on the + attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill tokens. + - The number of encoder tokens. + - The number of decode tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + # There are no encoder tokens for DECODER and ENCODER_ONLY + # attention type. 
+ num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 59e6e0015bbdf..41c23eb882e79 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -15,6 +15,7 @@ CommonMetadataBuilder, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + get_num_prefill_encode_decode_tokens, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -548,33 +549,8 @@ def forward( updated_slot_mapping, self.kv_cache_dtype, k_scale, v_scale) - - if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.DECODER: - # Decoder self-attention supports chunked prefill. - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - else: # attn_type == AttentionType.ENCODER_DECODER - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ + get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -582,8 +558,13 @@ def forward( # QKV for prefill. 
query = query[:num_prefill_tokens] if key is not None and value is not None: - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] + if (attn_type == AttentionType.ENCODER or \ + attn_type == AttentionType.ENCODER_DECODER): + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] + else: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index b8f98e4185724..ffb64cb3618d3 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,6 +18,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, MultiModalRegistry) @@ -36,6 +37,8 @@ logger = init_logger(__name__) +_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] + @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -101,8 +104,7 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. ''' - self._maybe_force_supported_attention_backend(model_config.model) - + self._maybe_force_supported_attention_backend(model_config) super().__init__( model_config, parallel_config, @@ -118,11 +120,10 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _is_xformers_only_encoder_decoder_model(self, model: str) -> bool: - # The Llama 3.2 model implementation uses - return "llama-3.2-11b-vision-instruct" in model.lower() + def _is_xformers_only_encoder_decoder_model(self, model: ModelConfig) -> bool: + return get_architecture_class_name(model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS - def _maybe_force_supported_attention_backend(self, model: str): + def _maybe_force_supported_attention_backend(self, model: ModelConfig): ''' Force vLLM to use the XFormers attention backend, which is currently the only supported option. 
@@ -143,9 +144,9 @@ def raise_backend_err(): # The user has not already specified an attention backend # override logger.info( - "Encoder-Decoder Model %s requires XFormers backend; " - "overriding backend auto-selection and " - "forcing XFormers.", model) + "Encoder-Decoder Model Architecture %s requires XFormers " + "backend; overriding backend auto-selection and " + "forcing XFormers.", get_architecture_class_name(model)) global_force_attn_backend(_Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes From 77ee5e234e048f0a57598196ea6b00b8dc494ce3 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 30 Oct 2024 06:23:50 +0000 Subject: [PATCH 24/32] Format --- tests/worker/test_encoder_decoder_model_runner.py | 1 + vllm/attention/backends/flash_attn.py | 12 +++++------- vllm/attention/backends/xformers.py | 10 ++++------ vllm/worker/enc_dec_model_runner.py | 8 +++++--- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 3127f8a290813..e75884a7395e2 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -13,6 +13,7 @@ BATCH_SIZES = [1, 4, 16, 64, 256] + def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index fa92746e5b03a..12d053051525d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,16 +9,14 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_encode_decode_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from .utils import (PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, - get_seq_len_block_table_args, - get_num_prefill_encode_decode_tokens, - is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) - if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 41c23eb882e79..719eae0bfb852 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,12 +11,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder, - get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, - get_num_prefill_encode_decode_tokens, - is_all_encoder_attn_metadata_set) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_encode_decode_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger diff --git a/vllm/worker/enc_dec_model_runner.py 
b/vllm/worker/enc_dec_model_runner.py index ffb64cb3618d3..bed53a95c1504 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,8 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata -from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams @@ -120,8 +120,10 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _is_xformers_only_encoder_decoder_model(self, model: ModelConfig) -> bool: - return get_architecture_class_name(model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS + def _is_xformers_only_encoder_decoder_model(self, + model: ModelConfig) -> bool: + return get_architecture_class_name( + model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS def _maybe_force_supported_attention_backend(self, model: ModelConfig): ''' From b147fb9be5117215c0afbb46b01780dff28c75ca Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 30 Oct 2024 07:06:30 +0000 Subject: [PATCH 25/32] Comments --- tests/kernels/test_encoder_decoder_attn.py | 6 ++++ vllm/attention/backends/flash_attn.py | 12 +++++--- vllm/attention/backends/xformers.py | 33 +--------------------- 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2eebf16961832..a1dd5eeeaa398 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -612,6 +612,8 @@ def _run_encoder_attention_test( (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed {query,key,value} and @@ -663,6 +665,8 @@ def _run_decoder_self_attention_test( query/key/value fields * attn_metadata: attention metadata for decoder-self attention (contains KV cache memory-mapping) + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -725,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test( (number_of_tokens x num_heads x head_size) key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 12d053051525d..d3507b70df941 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -798,12 +798,17 @@ def unified_flash_attention( if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. 
In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the KV + # cache is already populated with the cross-attention tensor. + # Thus, we skip cache updates during this time. if (attn_type != AttentionType.ENCODER) and (key is not None) and ( value is not None): if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running updated_slot_mapping = attn_metadata.cross_slot_mapping else: # Update self-attention KV cache (prefill/decode) @@ -817,8 +822,7 @@ def unified_flash_attention( value, kv_cache[0], kv_cache[1], - updated_slot_mapping.flatten() - if updated_slot_mapping is not None else None, + updated_slot_mapping.flatten(), # type: ignore[union-attr] kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 719eae0bfb852..0c18eb16a2a04 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -329,37 +329,6 @@ def _set_attn_bias( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args( - attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - return get_seq_len_block_table_args(attn_metadata, is_prompt, attn_type) - - class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): _metadata_cls = XFormersMetadata @@ -616,7 +585,7 @@ def forward( seq_lens_arg, max_seq_len_arg, block_tables_arg, - ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, From 7284de5d37a4eec41741455ebeab013ec5d73306 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 30 Oct 2024 07:24:55 +0000 Subject: [PATCH 26/32] Comment --- vllm/worker/enc_dec_model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bed53a95c1504..881caba9501fc 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -37,6 +37,9 @@ logger = init_logger(__name__) +# The Mllama model has PagedAttention specific logic because of which it +# can only be run with the XFORMERS backend +# TODO Make Mllama model work with Flash Attention backend. 
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] From 282a9188a2a12b9aeccd192cf3f6e11d41771dc9 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 1 Nov 2024 00:40:49 +0000 Subject: [PATCH 27/32] Comments --- tests/encoder_decoder/test_e2e_correctness.py | 87 +++++++++++-------- vllm/attention/backends/flash_attn.py | 28 +++--- vllm/attention/backends/utils.py | 34 ++++---- vllm/attention/backends/xformers.py | 34 ++++---- 4 files changed, 96 insertions(+), 87 deletions(-) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 3f3f31f82eed0..ec5a1ed22c1c6 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,12 +7,18 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close +LIST_ENC_DEC_SUPPORTED_BACKENDS = [ + _Backend.XFORMERS, _Backend.FLASH_ATTN, None +] + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -29,7 +35,8 @@ def vllm_to_hf_output( @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["bfloat16", "float"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @@ -48,6 +55,7 @@ def test_encoder_decoder_e2e( num_logprobs: int, decoder_prompt_type: DecoderPromptType, enforce_eager: bool, + attn_backend: _Backend, ) -> None: ''' End-to-End (E2E) test for the encoder-decoder framework. @@ -56,43 +64,48 @@ def test_encoder_decoder_e2e( implementations to ensure that both implementations produce consistent and correct results. 
''' - test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + dtype = 'bfloat16' + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) + hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE + else 0) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d3507b70df941..330ebbb32e402 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,7 @@ AttentionType) from vllm.attention.backends.utils import ( PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_encode_decode_tokens, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context @@ -828,13 +828,14 @@ def unified_flash_attention( v_scale, ) - num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ - get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_tokens:] + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. 
- query = query[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + query = query[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None @@ -848,13 +849,8 @@ def unified_flash_attention( q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ _get_query_key_seq_metadata(prefill_meta, True, attn_type) - if (attn_type == AttentionType.ENCODER or \ - attn_type == AttentionType.ENCODER_DECODER): - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] - else: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] prefill_output = flash_attn_varlen_func( q=query, @@ -937,10 +933,10 @@ def unified_flash_attention( if prefill_output is None: assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) + return decode_output.view(num_decode_query_tokens, hidden_size) if decode_output is None: assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) + return prefill_output.view(num_prefill_query_tokens, hidden_size) # Chunked prefill does not work with speculative decoding. # Therefore, the query length for decode should be 1 in chunked prefill. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 41de2a825995b..60ef00078345f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -514,13 +514,13 @@ def get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def get_num_prefill_encode_decode_tokens( +def get_num_prefill_decode_query_kv_tokens( attn_metadata, attn_type: AttentionType, ) -> Tuple[int, int, int]: """ - Calculate the number of prefill, encoder, and decode tokens based on the - attention metadata and the specified attention type. + Calculate the number of prefill and decode tokens for query, key or value + based on the attention metadata and the specified attention type. Args: attn_metadata (FlashAttentionMetadata): Attention Metadata object. @@ -535,23 +535,27 @@ def get_num_prefill_encode_decode_tokens( AssertionError: If the number of encoder tokens in `attn_metadata` is `None` when required for the calculations. """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 if attn_type == AttentionType.ENCODER: # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 elif attn_type == AttentionType.ENCODER_DECODER: assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. 
+ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens else: # attn_type == AttentionType.DECODER or # attn_type == AttentionType.ENCODER_ONLY - # There are no encoder tokens for DECODER and ENCODER_ONLY - # attention type. - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = 0 - num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens - return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens) + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0c18eb16a2a04..45b05004ad9a2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -13,7 +13,7 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.utils import ( CommonAttentionState, CommonMetadataBuilder, - get_num_prefill_encode_decode_tokens, get_seq_len_block_table_args, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -516,25 +516,21 @@ def forward( updated_slot_mapping, self.kv_cache_dtype, k_scale, v_scale) - num_prefill_tokens, num_encoder_tokens, num_decode_tokens = \ - get_num_prefill_encode_decode_tokens(attn_metadata, attn_type) + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] + query = query[:num_prefill_query_tokens] if key is not None and value is not None: - if (attn_type == AttentionType.ENCODER or \ - attn_type == AttentionType.ENCODER_DECODER): - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] - else: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -544,8 +540,8 @@ def forward( # prefix. 
out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out else: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have prefix attention.") @@ -574,8 +570,8 @@ def forward( k_scale, v_scale, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -587,7 +583,7 @@ def forward( block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, From c39d4c9bffa0a387ffcf74e6307146a3c1dd987b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 1 Nov 2024 00:53:38 +0000 Subject: [PATCH 28/32] Comments --- vllm/attention/backends/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 60ef00078345f..6c6e4635d6c26 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -519,7 +519,7 @@ def get_num_prefill_decode_query_kv_tokens( attn_type: AttentionType, ) -> Tuple[int, int, int]: """ - Calculate the number of prefill and decode tokens for query, key or value + Calculate the number of prefill and decode tokens for query, key/value based on the attention metadata and the specified attention type. Args: @@ -527,9 +527,9 @@ def get_num_prefill_decode_query_kv_tokens( attn_type (AttentionType): The type of attention being used. Returns: Tuple[int, int, int]: A tuple containing three integers: - - The number of prefill tokens. - - The number of encoder tokens. - - The number of decode tokens. + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. 
Raises: AssertionError: If the number of encoder tokens in `attn_metadata` From cc58ebedea2838876d90086b87158aac16874236 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 1 Nov 2024 01:00:09 +0000 Subject: [PATCH 29/32] Comments --- tests/encoder_decoder/test_e2e_correctness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index ec5a1ed22c1c6..f2d7e9fd78cf3 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -66,6 +66,7 @@ def test_encoder_decoder_e2e( ''' with global_force_attn_backend_context_manager(attn_backend): if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type dtype = 'bfloat16' test_case_prompts = example_encoder_decoder_prompts[ decoder_prompt_type] From 834572f766fd85a6f471db492a111709d25b1987 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 1 Nov 2024 17:14:53 +0000 Subject: [PATCH 30/32] Comments --- vllm/attention/backends/flash_attn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 41e1612af3d91..cfd7c9c1caff2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -756,8 +756,6 @@ def _get_causal_option(attn_type: AttentionType) -> bool: or attn_type == AttentionType.ENCODER_DECODER) -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, From 2264a62c4ae22c224e17efe0af3c71c3166c5828 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sat, 2 Nov 2024 01:10:39 +0000 Subject: [PATCH 31/32] Dummy --- vllm/attention/backends/flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2975a41797e9f..5583affaa491f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -964,7 +964,6 @@ def unified_flash_attention( output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) - def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, From 7ca0ab72a161df1caee415ce087fef04d7d52cda Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sat, 2 Nov 2024 01:12:17 +0000 Subject: [PATCH 32/32] Format --- vllm/attention/backends/flash_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5583affaa491f..2975a41797e9f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -964,6 +964,7 @@ def unified_flash_attention( output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) + def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor,
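
For reference, below is a minimal, self-contained Python sketch of the query/key-value token-split rule that this series consolidates into get_num_prefill_decode_query_kv_tokens() in vllm/attention/backends/utils.py and that both the FlashAttention and XFormers backends now call. SimpleMetadata and split_query_kv_tokens are illustrative stand-ins (not names from the patches); the real code operates on the backends' attention-metadata objects.

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class AttentionType(Enum):
    DECODER = 0
    ENCODER = 1
    ENCODER_ONLY = 2
    ENCODER_DECODER = 3


@dataclass
class SimpleMetadata:
    # Stand-in for the FlashAttention/XFormers attention-metadata objects.
    num_prefill_tokens: int
    num_decode_tokens: int
    num_encoder_tokens: Optional[int] = None


def split_query_kv_tokens(meta: SimpleMetadata,
                          attn_type: AttentionType) -> Tuple[int, int, int]:
    # Returns (num_prefill_query_tokens, num_prefill_kv_tokens,
    #          num_decode_query_tokens), mirroring the rule in the patches.
    if attn_type == AttentionType.ENCODER:
        # Encoder self-attention runs only at prefill; the encoder input
        # serves as both the query and the key/value.
        assert meta.num_encoder_tokens is not None
        return meta.num_encoder_tokens, meta.num_encoder_tokens, 0
    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention: queries come from the decoder tokens, while the
        # keys/values come from the already-computed encoder sequence.
        assert meta.num_encoder_tokens is not None
        return (meta.num_prefill_tokens, meta.num_encoder_tokens,
                meta.num_decode_tokens)
    # DECODER / ENCODER_ONLY: no encoder token count is involved.
    return (meta.num_prefill_tokens, meta.num_prefill_tokens,
            meta.num_decode_tokens)


if __name__ == "__main__":
    meta = SimpleMetadata(num_prefill_tokens=6,
                          num_decode_tokens=2,
                          num_encoder_tokens=10)
    # Expected output for cross-attention: (6, 10, 2)
    print(split_query_kv_tokens(meta, AttentionType.ENCODER_DECODER))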