diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index bef0c515b9073..f2d7e9fd78cf3 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,12 +7,18 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close +LIST_ENC_DEC_SUPPORTED_BACKENDS = [ + _Backend.XFORMERS, _Backend.FLASH_ATTN, None +] + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -29,7 +35,8 @@ def vllm_to_hf_output( @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @@ -48,6 +55,7 @@ def test_encoder_decoder_e2e( num_logprobs: int, decoder_prompt_type: DecoderPromptType, enforce_eager: bool, + attn_backend: _Backend, ) -> None: ''' End-to-End (E2E) test for the encoder-decoder framework. @@ -56,43 +64,49 @@ def test_encoder_decoder_e2e( implementations to ensure that both implementations produce consistent and correct results. ''' - test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) + hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE + else 0) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 4239afb70ff85..ad41b310b059a 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,13 +16,14 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, +from vllm.attention.selector import (_Backend, get_attn_backend, global_force_attn_backend_context_manager) +from vllm.forward_context import set_forward_context from vllm.platforms import current_platform # List of support backends for encoder/decoder models LIST_ENC_DEC_SUPPORTED_BACKENDS = ([ - _Backend.XFORMERS + _Backend.XFORMERS, _Backend.FLASH_ATTN ] if not current_platform.is_rocm() else [_Backend.ROCM_FLASH]) HEAD_SIZES = [64, 256] @@ -147,7 +148,8 @@ class that Attention will automatically select when it is constructed. test_pt.num_heads, test_pt.head_size, test_pt.block_size, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + backend=test_pt.backend_name) return TestResources(scale, attn_backend, attn, kv_cache) @@ -594,6 +596,7 @@ def _run_encoder_attention_test( attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder attention. @@ -612,6 +615,8 @@ def _run_encoder_attention_test( (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed {query,key,value} and @@ -621,20 +626,31 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run decoder self-attention test. @@ -652,6 +668,8 @@ def _run_decoder_self_attention_test( query/key/value fields * attn_metadata: attention metadata for decoder-self attention (contains KV cache memory-mapping) + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -662,12 +680,22 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -675,6 +703,7 @@ def _run_encoder_decoder_cross_attention_test( decoder_test_params: PhaseTestParameters, cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -703,6 +732,8 @@ def _run_encoder_decoder_cross_attention_test( (number_of_tokens x num_heads x head_size) key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -720,12 +751,37 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +@pytest.fixture(autouse=True) +def set_reset_environment(attn_backend): + # Set the default torch datatype to bfloat16 to enable + # testing of the Flash Attention backend. Also clear the + # cached value of the backend. + default_dtype = torch.get_default_dtype() + if attn_backend.name == 'FLASH_ATTN': + torch.set_default_dtype(torch.bfloat16) + get_attn_backend.cache_clear() + yield + # Reset the torch datatype to what it was before the test + # so as not to impact the remaining tests. + torch.set_default_dtype(default_dtype) @pytest.mark.skipif(current_platform.is_rocm(), @@ -775,10 +831,8 @@ def test_encoder_only( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -809,10 +863,14 @@ def test_encoder_only( # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=test_pt)) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -892,10 +950,8 @@ def test_e2e_enc_dec_attn( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -955,29 +1011,39 @@ def test_e2e_enc_dec_attn( enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_test_params, - prephase_attn_metadata) + prephase_attn_metadata, + test_pt=test_pt) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out) + prephase_dec_pckd_act_out, + attn_backend.name) # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_cross_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + prephase_cross_pckd_act_out, + attn_backend.name) # DECODE: build decode-phase attention metadata @@ -993,17 +1059,26 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + decphase_dec_pckd_act_out, + attn_backend.name) # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + None, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + decphase_cross_pckd_act_out, + attn_backend.name) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 83ee45f673bdc..03e0c5086eb11 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -537,10 +537,12 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( - seq_lens: Optional[List[int]], context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] -) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], - torch.Tensor, Optional[int]]: + seq_lens: Optional[List[int]], + context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], + device: Union[torch.device, str], +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], + torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -558,6 +560,8 @@ def _make_metadata_tensors( * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence + * encoder_seq_lens_tensor: encoder seq_lens list, as tensor + * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) @@ -571,8 +575,26 @@ def _make_metadata_tensors( seq_start_loc = None + if seq_lens_tensor is not None: + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, + max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -580,6 +602,7 @@ def make_kv_cache(num_blocks: int, head_size: int, block_size: int, device: Union[torch.device, str], + backend: str, default_val: float = 0.0) -> torch.Tensor: ''' Create a fake KV cache. @@ -596,10 +619,20 @@ def make_kv_cache(num_blocks: int, Returns: * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + * for backend 'XFORMERS' + * kv_cache: 2 x num_blocks x block_size x num_heads x head_size + * for backend 'FLASH_ATTN' ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) + if backend == 'XFORMERS': + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + elif backend == 'FLASH_ATTN': + kv_cache = torch.rand( + (2, num_blocks, block_size, num_heads, head_size)).to(device) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -863,8 +896,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -879,6 +913,7 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, @@ -887,6 +922,7 @@ def make_test_metadata( num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -909,8 +945,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -925,14 +962,17 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), + max_decode_query_len=1, context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -941,7 +981,8 @@ def make_test_metadata( def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: + output_under_test: torch.Tensor, + backend: str) -> None: ''' Assert that observed output matches the ideal output contained in the test parameters data structure. @@ -952,8 +993,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * output_under_test: actually observed output value ''' ideal_output = test_params.packed_qkvo.ideal_output - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == 'XFORMERS': + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output)) + + elif backend == 'FLASH_ATTN': + # For FlashAttention override the accuracy thresholds to non default + # values since we notice a higher difference between the ideal and + # actual output. + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output), + atol=0.01, + rtol=0.016) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") # Copied/modified from torch._refs.__init__.py diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 483773f069133..d686f1da3fa17 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -85,7 +85,7 @@ def run_test( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e3cd758d023bc..2e696d71afa02 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -10,10 +10,11 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import (async_tensor_h2d, direct_register_custom_op, @@ -73,7 +74,6 @@ def swap_blocks( src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @@ -85,6 +85,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) @@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata): # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: @@ -159,32 +197,52 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -194,17 +252,25 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_decode_metadata is not None: return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, @@ -214,9 +280,15 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, @@ -587,16 +659,20 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + output = torch.ops.vllm.unified_flash_attention( query, key, @@ -609,6 +685,7 @@ def forward( k_scale, v_scale, self.scale, + attn_type.value, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, @@ -617,6 +694,89 @@ def forward( return output +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: AttentionType) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) + + def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -629,60 +789,76 @@ def unified_flash_attention( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: + # Convert integer attn_type to enum + try: + attn_type = AttentionType(attn_type_int_val) + except ValueError as err: + raise AttributeError( + f"Invalid attention type {str(attn_type_int_val)}") from err + current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) + if (key is not None) and (value is not None): + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the KV + # cache is already populated with the cross-attention tensor. + # Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + k_scale, + v_scale, + ) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + query = query[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None @@ -690,22 +866,30 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, - causal=True, + causal=_get_causal_option(attn_type), window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ) else: # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) prefill_output = flash_attn_varlen_func( # noqa @@ -730,6 +914,8 @@ def unified_flash_attention( # because different queries might have different lengths. assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1") decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, @@ -747,12 +933,17 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, causal=True, window_size=window_size, @@ -762,10 +953,10 @@ def unified_flash_attention( if prefill_output is None: assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) + return decode_output.view(num_decode_query_tokens, hidden_size) if decode_output is None: assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) + return prefill_output.view(num_prefill_query_tokens, hidden_size) # Chunked prefill does not work with speculative decoding. # Therefore, the query length for decode should be 1 in chunked prefill. @@ -787,6 +978,7 @@ def unified_flash_attention_fake( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 52be85f64d3e8..488913cc4e74f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,14 +1,15 @@ """Attention backend utils""" from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) -from vllm.platforms import current_platform +from vllm.attention.backends.abstract import AttentionType +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -372,20 +373,15 @@ def get_graph_input_buffers( "block_tables": attn_metadata.decode_metadata.block_tables, } if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - if current_platform.is_rocm(): - assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), ( - f"Expected attn_backend name to be 'ROCM_FLASH', but " - f" got '{self.runner.attn_backend.get_name()}'") - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) - else: - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers def prepare_graph_input_buffers( @@ -442,6 +438,7 @@ def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, attn_metadata.encoder_seq_lens_tensor = torch.full( (batch_size, ), 1, dtype=torch.int).cuda() attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 def _add_additonal_input_buffers_for_enc_dec_model( self, attn_metadata, input_buffers: Dict[str, Any]): @@ -484,3 +481,122 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, input_buffers["cross_block_tables"].copy_( attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: AttentionType, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6dd561ef2c4f6..4df03d28f8011 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,8 +11,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None @@ -162,9 +169,7 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return is_all_encoder_attn_metadata_set(self) @property def is_all_cross_attn_metadata_set(self): @@ -173,9 +178,7 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return is_all_cross_attn_metadata_set(self) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -329,64 +332,6 @@ def _set_attn_bias( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args( - attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - - # No block tables associated with encoder attention - return (attn_metadata.seq_lens_tensor, - attn_metadata.max_prefill_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): _metadata_cls = XFormersMetadata @@ -575,45 +520,21 @@ def forward( updated_slot_mapping, self.kv_cache_dtype, k_scale, v_scale) - - if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.DECODER: - # Decoder self-attention supports chunked prefill. - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - else: # attn_type == AttentionType.ENCODER_DECODER - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] + query = query[:num_prefill_query_tokens] if key is not None and value is not None: - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -623,8 +544,8 @@ def forward( # prefix. out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out else: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have prefix attention.") @@ -653,8 +574,8 @@ def forward( k_scale, v_scale, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -664,9 +585,9 @@ def forward( seq_lens_arg, max_seq_len_arg, block_tables_arg, - ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 376b3136f0fb8..8a59cf41a689e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -98,7 +98,6 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( @@ -108,6 +107,7 @@ def get_attn_backend( backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: + logger.info("Using Flash Attention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index cbdacf779b089..0543ca978b7dd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -624,8 +624,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - - input_ids = input_ids.view(-1, input_ids.shape[-1]) inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions( diff --git a/vllm/utils.py b/vllm/utils.py index 63cd275f72deb..ceee6d9d799b9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -80,8 +80,8 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " - "currently supported with encoder/" +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " + "backends currently supported with encoder/" "decoder models.") STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 13ddd7f43eaa3..efbd2ae523d60 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -19,6 +19,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, MultiModalRegistry) from vllm.platforms import current_platform @@ -37,6 +38,11 @@ logger = init_logger(__name__) +# The Mllama model has PagedAttention specific logic because of which it +# can only be run with the XFORMERS backend +# TODO Make Mllama model work with Flash Attention backend. +_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] + @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -102,9 +108,7 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. ''' - - self._maybe_force_supported_attention_backend() - + self._maybe_force_supported_attention_backend(model_config) super().__init__( model_config, parallel_config, @@ -120,7 +124,12 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _maybe_force_supported_attention_backend(self): + def _is_xformers_only_encoder_decoder_model(self, + model: ModelConfig) -> bool: + return get_architecture_class_name( + model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS + + def _maybe_force_supported_attention_backend(self, model: ModelConfig): ''' Force vLLM to use the XFormers or ROCM attention backend, which is currently the only supported option. @@ -136,25 +145,27 @@ def raise_backend_err(): is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None - if not (is_forced_by_global or is_forced_by_env_var): + if not (is_forced_by_global or is_forced_by_env_var) \ + and self._is_xformers_only_encoder_decoder_model(model): # The user has not already specified an attention backend # override - logger.info("EncoderDecoderModelRunner requires " - "XFormers or ROCM backend; overriding backend " - "auto-selection and forcing XFormers or ROCM.") + logger.info( + "Encoder-Decoder Model Architecture %s requires XFormers " + "backend; overriding backend auto-selection and " + "forcing XFormers.", get_architecture_class_name(model)) global_force_attn_backend(_Backend.ROCM_FLASH if current_platform. is_rocm() else _Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. - if maybe_global_forced_backend != _Backend.XFORMERS and \ - maybe_global_forced_backend != _Backend.ROCM_FLASH: + if maybe_global_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() elif is_forced_by_env_var: # Backend override enforced by vLLM backend # environment variable - if maybe_env_var_forced_backend != _Backend.XFORMERS and \ - maybe_global_forced_backend != _Backend.ROCM_FLASH: + if maybe_env_var_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() def _list_to_int32_tensor( @@ -536,6 +547,7 @@ def _prepare_encoder_model_input_tensors( attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_lens_tensor, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, attn_metadata.cross_slot_mapping, attn_metadata.cross_block_tables, ) = ( @@ -543,6 +555,7 @@ def _prepare_encoder_model_input_tensors( encoder_seq_lens, encoder_seq_lens_tensor, max_encoder_seq_len, + encoder_seq_start_loc, cross_slot_mapping_tensor, cross_block_tables, )