diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 8e6c50666e70c..d9404e6442616 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported sliding window backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL def test_invalid_env(monkeypatch): diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py new file mode 100644 index 0000000000000..f25e7d480b6b3 --- /dev/null +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -0,0 +1,953 @@ +""" +Tests: + +* E2E test of Encoder attention + Decoder self-attention + + Encoder/decoder cross-attention (collectively + "encoder/decoder attention") +* Confirm enc/dec models will fail for chunked prefill +* Confirm enc/dec models will fail for prefix caching + +""" + +from typing import NamedTuple, Optional + +import pytest +import torch + +from tests.kernels.utils import * +from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP +from vllm.utils import is_hip + +HEAD_SIZES = [64, 256] + +NUM_HEADS = [1, 16] + +BATCH_SIZES = [1, 16] +BLOCK_SIZES = [16] +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] +CUDA_DEVICE = "cuda:0" + +MAX_DEC_SEQ_LENS = [128] +MAX_ENC_SEQ_LENS = [128] + +# Narrow teest-cases for unsupported-scenario +# tests +HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] + + +class TestPoint(NamedTuple): + """ + Encapsulates the attributes which define a single invocation + of the test_e2e_enc_dec_attn() test + + Attributes: + num_heads: The number of heads in the model. + head_size: Head dimension + backend_name: Name of the backend framework used. + batch_size: Number of samples per batch. + block_size: Size of each block of data processed. + max_dec_seq_len: Maximum sequence length for the decoder. + max_enc_seq_len: Maximum sequence length for the encoder. + num_blocks: Number of blocks in the model. 
+ """ + + num_heads: int + head_size: int + backend_name: str + batch_size: int + block_size: int + max_dec_seq_len: int + max_enc_seq_len: int + num_blocks: int + + +class TestResources(NamedTuple): + ''' + Encapsulates key components for performing an + encoder/decoder attention test + + Note that + (1) attn automatically selects an attention backend + based on platform info & a set of canned + heuristics + (2) attn_backend is thus *not the same backend + instance* used by attn, but rather it is + intended to be a + *different instance* of the *same backend class*; + it is assumed that the user of TestResources + will leverage attn_backend for the purpose of + constructing backend-compatible attention + metadata instances + + Attributes: + + * scale: 1/sqrt(d) scale factor for attn + * attn_backend: implementatino of abstraction + attention interface using + a particular kernel library + i.e. XFormers + * attn: Attention layer instance + * kv_cache: shared key/value cache for all attention + ''' + + scale: float + attn_backend: AttentionBackend + attn: Attention + kv_cache: torch.Tensor + + +def _make_test_resources(test_pt: TestPoint, ) -> TestResources: + ''' + Build key components for performing encoder/decoder attention test. + + Note that + (1) The Attention instance constructed here, automatically selects + an attention backend class based on platform info & a set of canned + heuristics, so + (2) The attention backend instance constructed here is thus *not + the same backend instance* used by attn, but rather it is + intended to be a *different instance* of the *same backend class*; + therefore, + (3) This function requires that test_pt.backend_name matches the backend + class that Attention will automatically select when it is constructed. + + + Arguments: + + * test_pt: TestPoint data structure; this function relies on the + following fields: num_heads, head_size, num_blocks, + block_size, backend_name + + Returns: + + * TestResources data structure. + ''' + + scale = float(1.0 / (test_pt.head_size**0.5)) + attn_backend = make_backend(test_pt.backend_name) + attn = Attention( + test_pt.num_heads, + test_pt.head_size, + scale=scale, + ) + if test_pt.num_blocks is None or test_pt.num_heads is None: + # Caller does not require a KV cache + return TestResources(scale, attn_backend, attn, None) + + # Construct KV cache + kv_cache = make_kv_cache(test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, + device=CUDA_DEVICE) + return TestResources(scale, attn_backend, attn, kv_cache) + + +def _encoder_attn_setup( + test_pt: TestPoint, + test_rsrcs: TestResources, +) -> PhaseTestParameters: + ''' + Set up test vectors & data structures for encoder attention test. + + A triplet of synthetic query/key/value tensors are constructed. + Given this is an encoder attention test, the key & value + sequences will have the same length as the corresponding queries. + + The query/key/value tensors are passed to an ideal reference + self-attention implementation to generate an ideal output tensor. 
+ + Encoder inference does not populate the KV cache, therefore + no KV cache memory mapping is constructed + + Arguments: + + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + + + Returns: + + * PhaseTestParameters data structure comprising (1) packed query/key/value + tensors, (2) the ideal output of attention computed using a naive + implementation, and (3) KVCache field set to None + ''' + + ( + num_heads, + head_size, + _, + batch_size, + _, + _, + max_q_seq_len, + _, + ) = test_pt + + scale = test_rsrcs.scale + + max_kv_seq_len = max_q_seq_len + + # Make test tensors + + qkv_in, _, _ = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # Compute correct answer using naive non-causal attention + # implementation + + ideal_output = ref_masked_attention(qkv_in.query, + qkv_in.key, + qkv_in.value, + scale=scale, + q_seq_lens=qkv_in.q_seq_lens, + kv_seq_lens=qkv_in.kv_seq_lens) + + packed_ideal_output, _ = pack_tensor(ideal_output, + qkv_in.q_seq_lens, + device=CUDA_DEVICE) + + packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) + + return PhaseTestParameters( + PackedQKVO(packed_qkv, packed_ideal_output), + None # No KV cache + ) + + +def _decoder_attn_setup( + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, +) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: + ''' + Set up test vectors & data structures for self-attention test. + + A triplet of synthetic query/key/value tensors are constructed ("baseline" + query/key/value). Given this is a self-attention test, the key & value + sequences will have the same length as the corresponding queries. + + "Prefill" query/key/value tensors are derived by masking out the last value + in each baseline query/key/value. These tensors are used to test prefill & + populate KV cache for a subsequent decode test. + + "Decode" query/key/value tensors are derived by extracting *only* the last + value from each baseline query/key/value (i.e. complement of the prefill + tensors.) These tensors are used to test decode, conditional on the kv cache + being populated during the prefill test. + + The baseline query/key/value tensors are passed to an ideal reference + self-attention implementation to generate a "Baseline" ideal output tensor. + This tensor is split into the "Prefill" ideal output tensor (all but the + last element of each output sequence) and the "Decode" ideal output tensor + (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode + test results, respectively. 
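A toy illustration (assumed sizes, illustrative only) of the prefill/decode split described above: for a sequence of true length K, the prefill slice covers positions [0, K-1) and the decode slice is position K-1 alone, mirroring how the ideal output is split in the function body below.

```python
import torch

batch_size, max_len, num_heads, head_size = 2, 5, 1, 4
ideal_output = torch.rand(batch_size, max_len, num_heads, head_size)
q_seq_lens = [5, 3]                        # unpadded per-sequence lengths
prefill_q_seq_lens = [n - 1 for n in q_seq_lens]

prefill_ideal = torch.zeros_like(ideal_output)
decode_ideal = torch.zeros_like(ideal_output[:, 0:1])
for bdx, plen in enumerate(prefill_q_seq_lens):
    prefill_ideal[bdx, :plen] = ideal_output[bdx, :plen]
    decode_ideal[bdx, :] = ideal_output[bdx, plen:plen + 1]

assert decode_ideal.shape == (batch_size, 1, num_heads, head_size)
```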
+ + This function also constructs the self-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts at + block_base_addr + + Arguments: + + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address + + Returns: + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x + head_size) query/key/value tensors + * Prefill-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for prefill phase. + * Decode-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for decode phase. + * max_block_idx: max physical address in decoder self-attention block-table + (intended to be used as the base address for the encoder/ + decoder cross-attention block-table, which is not + constructed in this function) + ''' + + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_q_seq_len, + _, + _, + ) = test_pt + + scale = test_rsrcs.scale + + max_kv_seq_len = max_q_seq_len + + # Build test tensors + + ( + qkv, + prefill_qkv, + decode_qkv, + ) = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) + + # Compute correct answer using naive attention implementation + # with causal attention mask + + causal_mask = make_causal_mask(max_q_seq_len, + max_kv_seq_len).to(CUDA_DEVICE) + + ideal_output = ref_masked_attention(qkv.query, + qkv.key, + qkv.value, + scale=scale, + custom_mask=causal_mask, + q_seq_lens=qkv.q_seq_lens, + kv_seq_lens=qkv.kv_seq_lens) + + # Split out the prefill- & decode-phase ideal answers & pack them + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_qkv.q_seq_lens, + device=CUDA_DEVICE) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) + + # Build prefill- & decode-phase data structures + # for decoder self-attention. Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Prefill-phase: + # + # * Empty block-tables tensor + # * Slot-mapping with entries for prompt tokens + # + # Decode-phase: + # * Block-tables tensor with minimum number of blocks + # required by total num. 
tokens in the entirety of all sequences + # (including both prefill & decode) + # * Slot-mapping with entries for tokens that will be decoded in the + # current decode iteration + # + # Note: the format described above is simply mirroring what ModelRunner + # produces + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + + ( + decode_block_tables, + slot_mapping_list, + max_block_idx, + ) = make_block_tables_slot_mapping(block_size, + qkv.q_seq_lens, + device=CUDA_DEVICE, + block_base_addr=block_base_addr) + + ( + prefill_slot_mapping, + decode_slot_mapping, + ) = split_slot_mapping(slot_mapping_list, + qkv.q_seq_lens, + device=CUDA_DEVICE) + + prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) + + decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) + + return ( + qkv, + PhaseTestParameters( # Prefill test params + PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode test params + PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping)), + max_block_idx) + + +def _enc_dec_cross_attn_setup_reuses_query( + decoder_qkv: QKVInputs, + encoder_test_params: PhaseTestParameters, + prefill_decoder_phase_test_params: PhaseTestParameters, + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, +) -> Tuple[PhaseTestParameters, PhaseTestParameters]: + ''' + Set up test vectors & data structures for cross-attention test. + + A triplet of synthetic cross-attention key/value tensors are constructed + ("baseline" key/value). Given this is a cross-attention test, we assume + query tensors were already synthesized for a prior self-attention test and + will be reused for cross-attention. The key & value sequences generated here + may have a different length than the corresponding queries (as is often + the case for cross-attention between decoder and encoder sequences.) + + Cross attention key & value tensors do not grow during autoregressive + inference; thus this function obtains a single key/value pair suitable for + both prefill and decode. + + The "baseline" query tensor is received as an argument. The "baseline" + query/key/value tensors are passed to an ideal reference cross-attention + implementation to generate a "baseline" ideal output tensor. This tensor is + split into the "Prefill" ideal output tensor (all but the last element of + each output sequence) and the "Decode" ideal output tensor (*only* the last + element of each output sequence); the "Prefill" and "Decode" ideal output + tensors can be used to validate the prefill and decode test results, + respectively. + + This function also constructs the cross-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts at + block_base_addr. 
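The KV-cache memory mapping referred to throughout these setup helpers follows one convention, implemented by `make_block_tables_slot_mapping()` in `tests/kernels/utils.py`: the token at offset `idx` of a sequence lives in block `block_table[idx // block_size]`, at flat slot `block_table[idx // block_size] * block_size + (idx % block_size)`. A standalone sketch with assumed toy values:

```python
from typing import List


def slots_for_sequence(block_table: List[int], num_tokens: int,
                       block_size: int) -> List[int]:
    # Flat KV-cache slot index for each token offset in one sequence.
    return [
        block_table[idx // block_size] * block_size + (idx % block_size)
        for idx in range(num_tokens)
    ]


# A 5-token sequence mapped onto (descending) blocks [7, 6] with block_size=4.
print(slots_for_sequence([7, 6], 5, 4))   # -> [28, 29, 30, 31, 24]
```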
+ + Arguments: + + * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x + num_heads x head_size) decoder self-attention inputs; + this function relies on the query and q_seq_lens + fields + * encoder_test_params: PhaseTestParameters data structure which was + used for encoder inference; KV cache field + is not used by this function + * prefill_decoder_phase_test_params: PhaseTestParameters data structure + used for prefill-phase decoder + self-attention; all fields + including KV cache required + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address + + Returns: + + * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for prefill phase. + * Decode-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for decode phase. + ''' + + assert encoder_test_params.packed_qkvo.packed_qkv is not None + assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None + + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_decoder_seq_len, + max_encoder_seq_len, + _, + ) = test_pt + + scale = test_rsrcs.scale + + decoder_query = decoder_qkv.query + decoder_seq_lens = decoder_qkv.q_seq_lens + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = ( + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) + + assert prefill_q_seq_lens is not None + + ( + cross_kv, + _, + _, + ) = make_qkv(batch_size, + max_decoder_seq_len, + max_encoder_seq_len, + num_heads, + head_size, + force_kv_seq_lens=encoder_seq_lens, + attn_type=AttentionType.ENCODER_DECODER, + device=CUDA_DEVICE) + + ideal_output = ref_masked_attention(decoder_query, + cross_kv.key, + cross_kv.value, + scale=scale, + q_seq_lens=decoder_seq_lens, + kv_seq_lens=cross_kv.kv_seq_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_seq_lens, + device=CUDA_DEVICE) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) + + # Build prefill- & decode-phase data structures + # for encoder/decoder cross-attention. 
Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Whereas decoder self-attention extracts relationships between + # equal-length Q/K/V sequences, which mutually grow in length + # with each decoded token, cross-attention relates the Q sequence + # - which grows with each new decoded token - to fixed-length + # K and V sequences derived from the encoder hidden states. + # + # Prefill-phase: + # + # * Empty block-tables tensor + # * Slot-mapping with as many entries as there are tokens in the encoder + # prompt. + # + # Decode-phase: + # * Block-tables tensor with minimum number of blocks to + # accommodate K & V tensors which are equal in lnegth + # to the encoder prompt length + # * Empty slot-mapping tensor (since K & V are fixed in size, + # new decoded tokens are not KV-cached and require no slot- + # mapping) + # + # Note: the format above is simply an extension of what ModelRunner + # produces for decoder-only models + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) + + ( + decode_block_tables, + prefill_slot_mapping_list, + _, + ) = make_block_tables_slot_mapping(block_size, + cross_kv.kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) + + prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, + device=CUDA_DEVICE) + + # Packed key/value (query is already provided) + packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) + + return ( + PhaseTestParameters( # Prefill-phase test params + PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode-phase test params + PackedQKVO(None, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping))) + + +def _run_encoder_attention_test( + attn: Attention, + encoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: + ''' + Run encoder attention. + + attn.forward() is passed attn_type=AttentionType.ENCODER in order + to configure the kernel invocation for encoder attention + + Requires attn_metadata.num_decode_tokens == 0 + (There is no encoder execution in the decode-phase) + + Arguments: + + * attn: Attention wrapper instance + * encoder_test_params: encoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields + * attn_metadata: attention metadata for encoder/decoder-self attention + + Returns: + * Attention.forward() applied to packed {query,key,value} and + & attn_metadata + ''' + assert attn_metadata.num_decode_tokens == 0 + attn_type = AttentionType.ENCODER + packed_qkv = encoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + None, + attn_metadata, + attn_type=attn_type) + + +def _run_decoder_self_attention_test( + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: + ''' + Run decoder self-attention test. + + attn.forward() is passed attn_type=AttentionType.DECODER + in order to configure the kernel invocation for decoder self-attention. 
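To orient the reader, the runner helpers in this file exercise four `Attention.forward()` flavors that differ only in which inputs are populated; a compact summary sketch (assumed, for orientation only — "packed" marks a real packed tensor, None an intentionally omitted input):

```python
calls = {
    "encoder (prefill only)":
        ("ENCODER", dict(kv_cache=None, key="packed", value="packed")),
    "decoder self-attention":
        ("DECODER", dict(kv_cache="shared", key="packed", value="packed")),
    "cross-attention (prefill)":
        ("ENCODER_DECODER", dict(kv_cache="shared", key="packed", value="packed")),
    "cross-attention (decode)":
        ("ENCODER_DECODER", dict(kv_cache="shared", key=None, value=None)),
}
for scenario, (attn_type, args) in calls.items():
    print(f"{scenario:27s} attn_type={attn_type:16s} {args}")
```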
+ + Arguments: + + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields + * attn_metadata: attention metadata for decoder-self attention + (contains KV cache memory-mapping) + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + attn_type = AttentionType.DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache + packed_qkv = decoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +def _run_encoder_decoder_cross_attention_test( + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + cross_test_params: Optional[PhaseTestParameters], + attn_metadata: AttentionMetadata, +) -> torch.Tensor: + ''' + Run encoder/decoder cross-attention test. + + Via PhaseTestParameters data structures, consumes the same query utilized + for decoder self-attention, plus a key/value specific to cross-attention. + + if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv + is None, this reflects that in decode-phase cross attention there + is no growth in the key and value tensors. + + attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER + in order to configure the kernel invocation for encoder/decoder cross- + attention. + + Arguments: + + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query field + * cross_test_params: encoder/decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + key/value fields + * attn_metadata: attention metadata for encoder/decoder-self attention + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + assert decoder_test_params.packed_qkvo.packed_qkv is not None + + attn_type = AttentionType.ENCODER_DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache + if cross_test_params is None: + key = None + value = None + else: + cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv + key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) + value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_encoder_only(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, max_dec_seq_len: int, + max_enc_seq_len: int, monkeypatch): + + # Force Attention wrapper backend + 
override_backend_env_variable(monkeypatch, backend_name) + + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Construct encoder attention test params (only used + # during prefill) + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + None, + decoder_test_params=None, + encoder_test_params=enc_test_params, + cross_test_params=None, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + + enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( + test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_e2e_enc_dec_attn( + num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch, +) -> None: + ''' + End-to-end encoder/decoder test: + + * Construct fake test vectors for (1) encoder attention, + (2) decoder self-attention, and (3) encoder/decoder cross-attention + * Construct (1) attention metadata structure with self- and cross-attention + attributes for prefill-phase, and (2) an analogous attention metadata + structure but for decode-phase + * Test attention steps in the following order + + * Encoder attention + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * Besides being reflective of realistic use-cases, this order would + exacerbate any accidental overlap in the self-/cross-attention + block tables, which one hopes to avoid + + + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + + This test utilizes PyTest monkey patching to force the attention backend + via an environment variable. + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). + + Note on metadata: there is a single attention metadata structure shared by + all prefill-phase attention operations (encoder, decoder, enc/dec cross), + and a single one shared by all decode-phase attention operations + (decoder & enc/dec cross.) 
This is intended to reflect the behavior + of ModelRunner, which constructs a single attention metadata structure for + each prefill or decode run. A realistic scenario would rely on the + attention backend to utilize the appropriate attention metadata fields + according to the value of attn_metadata.attention_type. Thus, this test is + organized so as to confirm that the backend-under-test can handle a + shared prefill attention metadata structure & a shared decode attention + metadata structure. + ''' + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Construct encoder attention test params (only used + # during prefill) + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Construct Decoder self-attention prefill-phase & decode-phase + # test params, including query/key/value tensors, decoder self-attention + # memory-mapping. cross_block_base_addr is the uppermost address in the + # decoder self-attention block-table, i.e. a base address which the + # encoder/decoder cross-attention block-table may build downward toward. + + ( + dec_qkv, + prephase_dec_test_params, + decphase_dec_test_params, + cross_block_base_addr, + ) = _decoder_attn_setup(test_pt, test_rsrcs) + + # Construct encoder/decoder cross-attention prefill-phase & decode-phase + # test params, including key/value tensors, cross-attention memory-mapping + + ( + prephase_cross_test_params, + decphase_cross_test_params, + ) = _enc_dec_cross_attn_setup_reuses_query( + dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr=cross_block_base_addr) + + # Shared prefill metadata structure + assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + + enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) + + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + + # PREFILL: decoder self-attention test + + prephase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + + # - Is prefill decoder self-attention correct? + assert_actual_matches_ideal(prephase_dec_test_params, + prephase_dec_pckd_act_out) + + # PREFILL: encoder/decoder cross-attention test + + prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, + prephase_attn_metadata) + + # - Is prefill encoder/decoder cross-attention correct? 
+ assert_actual_matches_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) + + # DECODE: build decode-phase attention metadata + + decphase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + False, + dec_qkv.q_seq_lens, + decoder_test_params=decphase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=decphase_cross_test_params, + device=CUDA_DEVICE) + + # DECODE: decoder self-attention test + + decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + + # - Is decode-phase decoder self-attention correct? + assert_actual_matches_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) + + # DECODE: encoder/decoder cross-attention test + + decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + + # - Is decode-phase encoder/decoder cross-attention correct? + assert_actual_matches_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b401eb87d3ec3..23d627820d247 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,12 +1,211 @@ """Kernel test utils""" +import itertools +import random +from numbers import Number +from typing import Any, List, NamedTuple, Optional, Tuple, Union + import pytest +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.backends.xformers import XFormersBackend +from vllm.utils import make_tensor_with_pad +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" +class QKVInputs(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + query/key/values and their sequence lengths. 
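As a concrete (toy-sized, illustrative) picture of the "unpacked" layout these NamedTuples carry: tensors are shaped (batch_size, padded_seq_len, num_heads, head_size), and `make_qkv()` below zeroes every position at or beyond a sequence's true length.

```python
import torch

batch_size, padded_seq_len, num_heads, head_size = 2, 4, 1, 2
q_seq_lens = [4, 2]                        # true (unpadded) lengths
query = torch.rand(batch_size, padded_seq_len, num_heads, head_size)
for bdx, seq_len in enumerate(q_seq_lens):
    query[bdx, seq_len:] = 0               # zero out the padding tail

assert torch.all(query[1, 2:] == 0)
```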
+ + Attributes: + + * {query,key,value}: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_seq_lens: List[int] + kv_seq_lens: List[int] + + +class QKVO(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + alongside unpacked known-correct attention output + + Attributes: + + * qkv: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * ideal_output: unpacked (batch_size x padded_seq_len x + num_heads x head_size) known-correct attention output + ''' + + qkv: QKVInputs + ideal_output: torch.Tensor + + +class PackedQKVInputs(NamedTuple): + ''' + Data structure for representing packed attention inputs + + Attributes: + + * {query,key,value}: packed (number_of_tokens x num_heads + x head_size) attention inputs + * q_start_loc_list: list of query start locations within packed tensor + * kv_start_loc_list: shared list of key/value start locations within + packed tensor + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_start_loc_list: Optional[List[int]] + kv_start_loc_list: Optional[List[int]] + q_seq_lens: Optional[List[int]] + kv_seq_lens: Optional[List[int]] + + +class PackedQKVO(NamedTuple): + ''' + Data structure for representing packed attention inputs, + alongside packed known-correct attention output + + Attributes: + + * packed_qkv: packed (number_of_tokens x num_heads + x head_size) attention inputs + * ideal_output: packed (number_of_tokens x num_heads + x head_size) known-correct attention output + ''' + + packed_qkv: Optional[PackedQKVInputs] + ideal_output: torch.Tensor + + +class KVMemoryMap(NamedTuple): + ''' + Data structure for encapsulating KV cache memory mapping. 
+ + Attributes: + + * block_tables: KV cache block tables + * slot_mapping: mapping of sequence offset to physical address + ''' + + block_tables: torch.Tensor + slot_mapping: torch.Tensor + + +class PhaseTestParameters(NamedTuple): + ''' + Data structure for encapsulating the test parameters + for a given test "phase" (prefill or decode phase) and attention + scenario (encoder, decoder-self, encoder/decoder-cross) + + Attributes: + + * packed_qkvo: packed (number_of_tokens x num_heads + x head_size) attention inputs & known-correct + output + * kv_mmap: KV cache memory mapping, specific to this test phase & + attention scenario + ''' + + packed_qkvo: PackedQKVO + kv_mmap: Optional[KVMemoryMap] + + +def maybe_make_int_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + + +def maybe_make_long_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: Optional[List]) -> Optional[Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + + +def make_causal_mask( + q_max_seq_len: int, + kv_max_seq_len: int, +) -> torch.Tensor: + ''' + Create a q_max_seq_len x kv_max_seq_len causal mask + + Arguments: + + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len + + Returns: + + * 2D tensor, q_max_seq_len x kv_max_seq_len + ''' + + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) + return mask + + def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' @@ -20,3 +219,724 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * backend_name: attention backend name to force ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) + + +def ref_masked_attention(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: + ''' + "Golden" masked attention reference. Supports two types of masking: + + * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out + padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. 
+ causal + + Arguments: + + * query: batch_size x q_padded_seq_len x num_heads x head_size + * key: batch_size x kv_padded_seq_len x num_heads x head_size + * value: batch_size x kv_padded_seq_len x num_heads x head_size + * scale: Attention scale factor + * custom_mask: custom attention mask; good place to inject a causal + attention mask + * q_seq_lens: list of unpadded query seq_lens for each batch index + * kv_seq_lens: list of unpadded key/value seq_lens for each batch index + + Returns: + + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size + ''' + + assert q_seq_lens is not None + assert kv_seq_lens is not None + + batch_size = query.shape[0] + assert (len(q_seq_lens) == batch_size) + assert (len(kv_seq_lens) == batch_size) + + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() + + # Basic attention mask, derived from seq lens + if (q_seq_lens is not None) or (kv_seq_lens is not None): + attn_mask = torch.zeros_like(attn_weights) + if q_seq_lens is not None: + for bdx, plen in enumerate(q_seq_lens): + attn_mask[bdx, :, plen:, :] = -torch.inf + if kv_seq_lens is not None: + for bdx, plen in enumerate(kv_seq_lens): + attn_mask[bdx, :, :, plen:] = -torch.inf + + attn_weights = attn_weights + attn_mask.float() + + # Custom attention mask + if custom_mask is not None: + attn_weights = attn_weights + custom_mask.float() + + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) + return out + + +def make_qkv( + batch_size: int, + max_q_seq_len: int, + max_kv_seq_len: Optional[int], + num_heads: int, + head_size: int, + device: Union[torch.device, str], + force_kv_seq_lens: Optional[List[int]] = None, + attn_type: AttentionType = AttentionType.ENCODER_DECODER, + force_max_len: bool = False, +) -> Tuple[QKVInputs, QKVInputs, QKVInputs]: + ''' + Construct QKV test tensors for self- and cross-attention. + + Generates three query/key/value triplets: + + * "Baseline" query/key/value (for input to reference attention function) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as + input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, + for use as input to decode kernel) + + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v + seqlens + + Arguments: + + * batch_size + * max_q_seq_len: max query seq len + * max_kv_seq_len: max key/value seq len + * num_heads + * head_size + * is_encoder_decoder_attn: if True, query seqlen may differ from + key/value seqlen (as is often the case for cross-attention); + o/w, query/key/value seqlens match at each batch index + (max_kv_seq_len is unused) + * force_kv_seq_lens: if not None, overrides kv sequence lengths + * attn_type: encoder, decoder self, or enc/dec cross attention + * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query + seqlens are random in [2,max_q_seq_lens]. 
Same for key/value seqlens + and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False + * device: CPU or CUDA device + + Returns: + + * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) + * Prefill QKVInputs structure (containing all but the last sequence offset) + * Decode QKVInputs structure (containing all only the last sequence offset) + ''' + + if force_max_len: + q_seq_lens = [max_q_seq_len for _ in range(batch_size)] + else: + q_seq_lens = [ + random.randint(2, max_q_seq_len) for _ in range(batch_size) + ] + kv_seq_lens = None + if force_kv_seq_lens is not None: + kv_seq_lens = force_kv_seq_lens + elif attn_type != AttentionType.ENCODER_DECODER: + # K,V seq lens match Q for self-attention + kv_seq_lens = q_seq_lens + else: + # K,V seq lens are distinct from Q seq lens & random + assert max_kv_seq_len is not None + if force_max_len: + kv_seq_lens = [max_kv_seq_len] * batch_size + else: + kv_seq_lens = [ + random.randint(2, max_kv_seq_len) for _ in range(batch_size) + ] + + query = torch.rand( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + prefill_query = torch.zeros( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + prefill_key = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + prefill_value = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + decode_query = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, + kv_seq_lens)): + query[bdx, q_seq_len:, :, :] = 0 + key[bdx, kv_seq_len:, :, :] = 0 + value[bdx, kv_seq_len:, :, :] = 0 + + prefill_query[bdx, + 0:(q_seq_len - 1), :, :] = query[bdx, + 0:(q_seq_len - 1), :, :] + prefill_key[bdx, + 0:(kv_seq_len - 1), :, :] = key[bdx, + 0:(kv_seq_len - 1), :, :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + + decode_query[bdx, :, :, :] = query[bdx, + (q_seq_len - 1):q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, + (kv_seq_len - 1):kv_seq_len, :, :] + + prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] + prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] + + decode_q_seq_lens = [1 for _ in q_seq_lens] + decode_kv_seq_lens = [1 for _ in kv_seq_lens] + + return ( + QKVInputs( + query, # Overall QKV inputs + key, + value, + q_seq_lens, + kv_seq_lens), + QKVInputs( + prefill_query, # Prefill subset of QKV sequences + prefill_key, + prefill_value, + prefill_q_seq_lens, + prefill_kv_seq_lens), + QKVInputs( + decode_query, # Decode subset of KV sequences + decode_key, + decode_value, + decode_q_seq_lens, + decode_kv_seq_lens)) + + +def pack_tensor( + unpacked_tensor: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]: + ''' + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where + number_of_tokens = sum(seq_lens) + + Arguments: + + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size + * seq_lens: list of token counts for each seq + 
* device: CPU or CUDA device + + Returns + + * packed_tensor: number_of_tokens x num_heads x head_size + * start_loc_list: start idx of each batch elt in packed_tensor; [0] + + list(itertools.accumulate(seq_lens)) + ''' + + num_tok = sum(seq_lens) + num_heads = unpacked_tensor.shape[-2] + head_size = unpacked_tensor.shape[-1] + start_loc_list = [0] + list(itertools.accumulate(seq_lens)) + packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) + + for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): + + packed_tensor[start_loc:( + start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + + return packed_tensor, start_loc_list + + +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, + str]) -> PackedQKVInputs: + ''' + Individually pack each of Q, K and V, each with dimensions batch_size x + padded_seq_len x num_heads x head_size, into respective number_of_tokens x + num_heads x head_size tensors. + + For Q, number_of_tokens = sum(q_seq_lens). + + For K and V, number_of_tokens = sum(kv_seq_lens) + + Arguments: + + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size) + attention inputs + * device: CPU or CUDA device + + Returns + + * Packed (number_of_tokens x num_heads x head_size) QKV inputs + derived from unpacked inputs + ''' + + if qkv.query is None: + packed_query = None + q_start_loc_list = None + else: + packed_query, q_start_loc_list = pack_tensor(qkv.query, + qkv.q_seq_lens, + device=device) + packed_key, kv_start_loc_list = pack_tensor(qkv.key, + qkv.kv_seq_lens, + device=device) + packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) + return PackedQKVInputs( + packed_query, packed_key, packed_value, q_start_loc_list, + kv_start_loc_list, + (None if q_start_loc_list is None else qkv.q_seq_lens), + qkv.kv_seq_lens) + + +def make_backend(backend_name: str) -> AttentionBackend: + ''' + Construct the backend instance determined by the backend_name string + argument. + + "XFORMERS" -> construct xformers backend + + TODO: other backends + + Note: at time of writing the Attention wrapper automatically selects + its own backend for Attention.forward(); so the backend instance which + you generate with this function is not meant to be used for *running* + inference, but rather for generating compatible metadata structures + using backend.make_metadata() + + + Returns: + + * Backend instance + ''' + if backend_name == STR_XFORMERS_ATTN_VAL: + return XFormersBackend() + raise AssertionError( + f"Unrecognized backend_name {backend_name} for unit test") + + +def _make_metadata_tensors( + seq_lens: Optional[List[int]], context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], + torch.Tensor, Optional[int]]: + ''' + Build scalar & tensor values required to build attention metadata structure. 
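To make the packed layout produced by `pack_tensor()` above concrete, a small worked example with assumed toy sizes:

```python
import itertools

import torch

batch_size, padded_seq_len, num_heads, head_size = 2, 4, 1, 2
unpacked = torch.rand(batch_size, padded_seq_len, num_heads, head_size)
seq_lens = [3, 2]
start_locs = [0] + list(itertools.accumulate(seq_lens))   # [0, 3, 5]

packed = torch.zeros(sum(seq_lens), num_heads, head_size)
for bdx, (seq_len, start) in enumerate(zip(seq_lens, start_locs)):
    packed[start:start + seq_len] = unpacked[bdx, :seq_len]

assert packed.shape == (5, 1, 2)
assert torch.equal(packed[3:5], unpacked[1, :2])
```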
+ + Arguments: + + * seq_lens: list of token-counts for each decoder input seq + * context_lens: list of context length values for each seq + * encoder_seq_lens: list of token-counts for each encoder input seq + * device: CPU or CUDA device + + Returns: + + * seq_lens_tensor: decoder seq_lens list, as tensor + * context_lens_tensor: context_lens list, as tensor + * max_context_len: max(context_lens) + * max_seq_len: max(seq_lens) + * seq_start_loc: start idx of each sequence + * max_encoder_seq_len: encoder seq_lens list, as tensor + ''' + seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) + context_lens_tensor = maybe_make_int_tensor(context_lens, device) + max_context_len = maybe_max(context_lens) + max_seq_len = maybe_max(seq_lens) + + encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) + max_encoder_seq_len = (None if encoder_seq_lens is None else + max(encoder_seq_lens)) + + seq_start_loc = None + + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, + seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) + + +def make_kv_cache(num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + default_val: float = 0.0) -> torch.Tensor: + ''' + Create a fake KV cache. + + Arguments: + + * num_blocks: number of blocks in the KV cache + * num_heads: number of attention heads + * head_size: head dimension + * block_size: number of offsets within a block + * device: CPU or CUDA device + * default_val: initialization value for KV cache elements + + Returns: + + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + ''' + + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + if default_val is not None: + kv_cache[:, :, :] = default_val + return kv_cache + + +def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: + ''' + Compute the minimum number of blocks required to hold num_tokens tokens, + given block_size + ''' + return (num_tokens + block_size) // block_size + + +def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): + return maybe_make_long_tensor([], device) + + +def make_empty_block_tables_tensor(device: Union[torch.device, str]): + return torch.tensor([], device=device) + + +def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]): + ''' + Split a slot mapping into valid prefill- and decode-phase slot mappings. + + Context: + * Your goal is to test (1) prefill of N prompts, with prompt-lengths + {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token + for all N prompts (N tokens total); the resultant sequence lengths + after decode would be {K_i + 1 for i \\in [0,N)} + * The test you want to do requires (1) having the prefill slot mapping + for all tokens present during prefill, the number of which is + M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N + decoded tokens + + This function consumes a single 1D slot mapping, which is the + concatenation of N slot mappings each of length K_i + 1 (corresponding + to the sequence lengths after decode), with a total length of + P = \\sum_i{K_i + 1} = M + N + + The prefill-phase slot mapping results from excising the (K_i + 1)-th entry + from each of the N subsequences in the slot mapping (i.e. omitting the + decoded token's mapping.) 
+ + The N excised entries are appended to obtain the decode-phase slot mapping + + Arguments: + + * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N + post-decode sequences + * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the + description above) + * device: cuda, cpu, etc. + + Returns: + + * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor) + reflecting all N prefill prompts + * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting + all N decoded tokens + ''' + + prefill_slot_mapping = [] + decode_slot_mapping = [] + + base_idx = 0 + for seq_len in seq_lens: + prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + + seq_len - 1)]) + decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) + base_idx += seq_len + + return (maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device)) + + +def make_block_tables_slot_mapping( + block_size: int, + seq_lens: List[int], + device: Union[torch.device, str], + block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]: + ''' + Construct fake block tables & slot mappings. + + For a sequence with num_tokens tokens the minimum number + of required KV cache blocks is + + num_blocks = (num_tokens + block_size) // block_size + + Then the minimum KV cache size in blocks is + + total_cache_blocks = sum(num_blocks for all seqs) + + Then, the blocktable mapping counts downward from + + block_base_addr + total_cache_blocks + + to + + block_base_addr + + + The constructed block-tables and slot-mapping are sized to the + lengths of the sequences in their entirety (as reflected by seq_lens), + i.e. the total of prefill prompt tokens + decoded tokens. + + Arguments: + + * block_size: number of offsets per block + * seq_lens: list of token-counts for each sequence + * block_base_addr: the block table base address + * device: CPU or CUDA device + + Return: + + * block_tables_tensor: block table for sequence + * slot_mapping_list: slot mapping for sequence + * max_block_idx: the highest block address within this block table + ''' + + # Provision minimum number of KV cache blocks + num_blocks_list = [ + _num_tokens_to_min_blocks(num_tokens, block_size) + for num_tokens in seq_lens + ] + max_block_table_len = max(num_blocks_list) + block_table_pad_tokens = 10 + + block_tables = [] + slot_mapping_list = [] + # Compute uppermost address of block table + total_cache_blocks = sum(num_blocks_list) + block_base_idx = block_base_addr + total_cache_blocks + max_block_idx = block_base_idx + for sdx, num_tokens in enumerate(seq_lens): + num_blocks = num_blocks_list[sdx] + block_table = list( + range(block_base_idx, block_base_idx - num_blocks, -1)) + for idx in range(num_tokens): + mapping_value = ( + idx % block_size) + block_table[idx // block_size] * block_size + slot_mapping_list.append(mapping_value) + + block_base_idx -= num_blocks + block_tables.append(block_table) + + block_tables_tensor = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len + block_table_pad_tokens, + pad=0, + dtype=torch.int, + device=device, + ) + + return (block_tables_tensor, slot_mapping_list, max_block_idx) + + +def make_test_metadata( + attn_backend: AttentionBackend, + is_prompt: bool, + seq_lens: Optional[List[int]], + decoder_test_params: Optional[PhaseTestParameters], + device: Union[torch.device, str], + encoder_test_params: Optional[PhaseTestParameters] = None, + cross_test_params: Optional[PhaseTestParameters] = None +) -> 
AttentionMetadata: + ''' + Construct fake attention metadata for a given test phase + (prefill-phase or decode-phase). + + encoder_test_params and cross_test_params arguments allow encoder + attention and enc/dec cross-attention (respectively) to use distinct + metadata values from decoder self-attention (decoder_test_params.) + + if encoder_test_params and cross_test_params are None, the attention + metadata will support decoder-only scenario. + + Assumptions: + + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + + Arguments: + + * attn_backend: Backend for sourcing attention kernels + * is_prompt: prefill if True, o/w decode + * seq_lens: list of token counts for each sequence + * decoder_test_params: decoder self-attention test params; + this function requires + kv_mmap (memory mapping) field + * device: CPU or CUDA device + * encoder_test_params: encoder attention test params; + this function requires encoder query + sequence lengths field. If None, + encoder query sequence lengths are + treated as None + * cross_test_params: enc/dec cross-attention test params; + this function requires kv_mmap field. + If None, KV cache memory map data + structures are treated as None + + Return: + + * AttentionMetadata structure + ''' + + # Decoder self-attention memory mapping + # decoder_test_params is None signals encoder-only + # scenario, so kv_mmap is None + kv_mmap = (None + if decoder_test_params is None else decoder_test_params.kv_mmap) + + # This function constructs metadata assuming no chunked prefill, + # i.e. 100% prefill tokens or 100% decode tokens + # + # - If is_prompt, num_prefills_or_decodes is the number of prefills + # and num_prefill_or_decode_tokens is the number of prefill tokens + # - If not is_prompt, num_prefills_or_decodes is the number of decodes + # and num_prefill_or_decode_tokens is the number of decode tokens + # + # seq_lens is None signals encoder-only + # scenario, in which case num_prefills_or_decodes and + # num_prefill_or_decode_tokens are unused + num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) + + num_prefill_or_decode_tokens = (None if seq_lens is None else ( + sum(seq_lens) if is_prompt else len(seq_lens))) + + # Seems for non-prefix-caching scenarios context_lens + # is never needed + context_lens = None + + if encoder_test_params is None: + encoder_seq_lens = None + num_encoder_tokens = None + else: + # Encoder/decoder or encoder-only models only: + # * Extract encoder input sequence lengths + assert encoder_test_params.packed_qkvo.packed_qkv is not None + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + num_encoder_tokens = (None if encoder_seq_lens is None else + (sum(encoder_seq_lens))) + + if cross_test_params is None: + cross_kv_mmap = None + else: + # Encoder/decoder or encoder-only models only: + # * Extract *cross-attention* slot_mapping and block table + # (kv_mmap) + cross_kv_mmap = cross_test_params.kv_mmap + + if is_prompt: + # Prefill-phase scenario + + num_prefills = num_prefills_or_decodes + num_prefill_tokens = num_prefill_or_decode_tokens + num_decode_tokens = 0 + + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), + num_prefill_tokens=num_prefill_tokens, + 
num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=None if seq_lens is None else max(seq_lens), + max_decode_seq_len=0, + context_lens_tensor=context_lens_tensor, + block_tables=(None if kv_mmap is None else kv_mmap.block_tables), + use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) + + else: # not is_prompt + # Decode-phase scenario + + assert kv_mmap is not None + assert num_prefill_or_decode_tokens is not None + assert seq_lens is not None + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = num_prefill_or_decode_tokens + + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=kv_mmap.slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=max(seq_lens), + context_lens_tensor=context_lens_tensor, + block_tables=kv_mmap.block_tables, + use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) + + +def assert_actual_matches_ideal(test_params: PhaseTestParameters, + output_under_test: torch.Tensor) -> None: + ''' + Assert that observed output matches the ideal output + contained in the test parameters data structure. + + Arguments: + + * test_params: Test parameters including packed ideal output + * output_under_test: actually observed output value + ''' + ideal_output = test_params.packed_qkvo.ideal_output + assert torch.allclose(ideal_output, + output_under_test.view_as(ideal_output)) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 40768532f59c2..adb8325168cdf 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,11 +1,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields +from enum import Enum, auto from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) import torch +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. 
K/V + + class AttentionBackend(ABC): """Abstract class for attention backends.""" @@ -128,5 +135,6 @@ def forward( kv_cache: torch.Tensor, attn_metadata: T, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 7b4578fcd8b9d..fe4c4a45dca0d 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -4,7 +4,7 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -328,6 +328,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -340,6 +341,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8cb5c3101a804..048abed48d2e9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashAttentionBackend(AttentionBackend): @@ -257,6 +257,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -269,6 +270,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." 
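The hunks above introduce the AttentionType enum and thread a new attn_type keyword through each backend's forward(); backends without encoder/decoder support guard against non-default values. Below is a minimal sketch of that guard pattern, using only names added in this diff; the helper function and the standalone assert are illustrative and not part of the patch.

    from vllm.attention.backends.abstract import AttentionType

    def _require_decoder_only(attn_type: AttentionType, impl_name: str) -> None:
        # Mirrors the check each non-enc/dec backend performs at the top of
        # forward(): accept the new attn_type keyword, but reject anything
        # other than decoder self-attention.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and encoder/decoder cross-attention "
                f"are not implemented for {impl_name}")

    # After this change, AttentionType has exactly three members.
    assert {t.name for t in AttentionType} == {
        "DECODER", "ENCODER", "ENCODER_DECODER"}
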
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a9ab2313013d7..b27e3e40f566d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -14,7 +14,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashInferBackend(AttentionBackend): @@ -224,8 +224,14 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: FlashInferMetadata, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 5114bfa6e1589..6a1295b1000bc 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -7,7 +7,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -157,6 +157,7 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: IpexAttnMetadata, # type: ignore kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -170,6 +171,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 22cb1a1bd1fd3..7a6954ceb6d6a 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -6,7 +6,7 @@ import torch_xla.experimental.dynamo_set_buffer_donor from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class PallasAttentionBackend(AttentionBackend): @@ -132,6 +132,7 @@ def forward( kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], attn_metadata: PallasMetadata, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. 
@@ -146,6 +147,11 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 31ae0751486f5..81b546c65c819 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -297,6 +297,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -309,6 +310,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "ROCmFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63f8466da9316..48418f24870f9 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu @@ -145,6 +145,7 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -158,6 +159,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000000000..a3cfc6e20748b --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,7 @@ +"""Attention backend utils""" + +# Error string(s) for encoder/decoder +# unsupported attention scenarios + +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ff449c3ff74f8..6cc5f1d1477ae 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -6,10 +6,11 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalCausalMask, + BlockDiagonalMask, LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. 
- # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache _cached_prefill_metadata: Optional["XFormersMetadata"] = None _cached_decode_metadata: Optional["XFormersMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -115,6 +141,28 @@ def __post_init__(self): # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -122,30 +170,50 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + query_start_loc=query_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -154,29 +222,146 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + slot_mapping=slot_mapping, + seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: AttentionType, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if attn_type == AttentionType.DECODER: + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + else: + # attn_type == AttentionType.ENCODER_DECODER + return attn_metadata.cross_attn_bias + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if attn_type == AttentionType.DECODER: + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_seq_len_block_table_args( + attn_metadata: XFormersMetadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -238,51 +423,144 @@ def __init__( def forward( self, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. 
+ * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache is not None): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
- PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + kv_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -294,10 +572,14 @@ def forward( # block tables are empty if the prompt does not have a cached # prefix. 
out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta) + query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, @@ -320,13 +602,20 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -343,6 +632,7 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -356,8 +646,12 @@ def _run_memory_efficient_xformers_forward( key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally """ - assert attn_metadata.seq_lens is not None + original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -375,18 +669,39 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. 
- if attn_metadata.attn_bias is None: + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Default enc/dec cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + elif attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens) + else: + assert attn_metadata.seq_lens is not None + + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) - attn_metadata.attn_bias = [attn_bias] + attn_bias = [attn_bias] else: - attn_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -400,7 +715,7 @@ def _run_memory_efficient_xformers_forward( query, key, value, - attn_bias=attn_metadata.attn_bias[0], + attn_bias=attn_bias[0], p=0.0, scale=self.scale) return out.view_as(original_query) @@ -409,6 +724,7 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None output = torch.empty_like(original_query) start = 0 for i, seq_len in enumerate(attn_metadata.seq_lens): @@ -417,7 +733,7 @@ def _run_memory_efficient_xformers_forward( query[None, start:end], key[None, start:end], value[None, start:end], - attn_bias=attn_metadata.attn_bias[i], + attn_bias=attn_bias[i], p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dfe93be462184..b8cc87be8c748 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -90,9 +90,16 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) + + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._kv_scale, + attn_type=attn_type) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore
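
Taken together, Attention.forward() now forwards attn_type to the backend implementation, and in this patch only the XFormers backend accepts non-decoder values. The sketch below shows how an encoder/decoder model layer might drive the new keyword; the function and tensor names are illustrative (not part of the patch), and it assumes the XFormers backend was selected and that attn_metadata carries the encoder/cross fields (e.g. as built by make_test_metadata above).

    import torch
    from vllm.attention import Attention, AttentionMetadata
    from vllm.attention.backends.abstract import AttentionType

    def encoder_self_attn(attn: Attention, q: torch.Tensor, k: torch.Tensor,
                          v: torch.Tensor,
                          attn_metadata: AttentionMetadata) -> torch.Tensor:
        # Encoder attention performs no KV caching, so no kv_cache is passed.
        return attn(q, k, v, None, attn_metadata,
                    attn_type=AttentionType.ENCODER)

    def cross_attn_prefill(attn: Attention, q: torch.Tensor, k: torch.Tensor,
                           v: torch.Tensor, kv_cache: torch.Tensor,
                           attn_metadata: AttentionMetadata) -> torch.Tensor:
        # Prefill: encoder-derived K/V are written to the "cross" KV cache
        # via attn_metadata.cross_slot_mapping.
        return attn(q, k, v, kv_cache, attn_metadata,
                    attn_type=AttentionType.ENCODER_DECODER)

    def cross_attn_decode(attn: Attention, q: torch.Tensor,
                          kv_cache: torch.Tensor,
                          attn_metadata: AttentionMetadata) -> torch.Tensor:
        # Decode: cross-attention K/V were cached at prefill time and do not
        # grow, so the XFormers impl accepts key=None and value=None here.
        return attn(q, None, None, kv_cache, attn_metadata,
                    attn_type=AttentionType.ENCODER_DECODER)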