From 4a5f7b6558e2b1b80464a5ee9e77e9164060f410 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 2 Jul 2024 00:32:52 +0000 Subject: [PATCH 01/12] unify flash attention backend kernel --- vllm/attention/backends/flash_attn.py | 171 ++++---------------------- vllm/worker/model_runner.py | 46 ++++--- 2 files changed, 49 insertions(+), 168 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8cb5c3101a804..5d3a2c32af681 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -73,10 +73,8 @@ class FlashAttentionMetadata(AttentionMetadata): updated from `CUDAGraphRunner.forward` API. """ # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] + # the computed tokens + new tokens. + seq_lens: List[int] # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| @@ -88,12 +86,6 @@ class FlashAttentionMetadata(AttentionMetadata): # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -119,70 +111,6 @@ class FlashAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
use_cuda_graph: bool - _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None - - self._cached_prefill_metadata = FlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = FlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - class FlashAttentionImpl(AttentionImpl): """ @@ -281,7 +209,6 @@ def forward( if kv_cache is not None: key_cache = kv_cache[0] value_cache = kv_cache[1] - # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. @@ -294,74 +221,30 @@ def forward( self.kv_cache_dtype, ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. 
- if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - out = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + if kv_cache is None or (attn_metadata.block_tables is not None + and attn_metadata.block_tables.numel()) == 0: + k = key + v = value + block_tables = None + else: + k = kv_cache[0] + v = kv_cache[1] + block_tables = attn_metadata.block_tables + + max_seq_len = max(attn_metadata.seq_lens) + output = flash_attn_varlen_func( + q=query, + k=k, + v=v, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + block_table=block_tables) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 942063677a427..2cc1068ba38b3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -574,7 +574,6 @@ def _prepare_model_input_tensors( batch_size = len(input_tokens) max_query_len = max(query_lens) - max_prefill_seq_len = max(prefill_seq_lens, default=0) max_decode_seq_len = max(decode_seq_lens, default=0) # If cuda graph can be used, pad tensors accordingly. 
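# --- Illustrative sketch (not part of the patch). The unified
# FlashAttentionImpl.forward() above now routes prefill, prefix-cached
# prefill, and decode through a single flash_attn_varlen_func call; decode
# requests are simply rows with query_len == 1. A hypothetical mixed batch,
# with cu_seqlens built the same way as in the builder code later in this
# series:
import torch

query_lens = [5, 3, 1, 1]      # two prefills, two decodes (1 new token each)
seq_lens = [5, 3, 130, 258]    # cached context + new tokens per request

query_start_loc = torch.tensor([0] + query_lens,
                               dtype=torch.int32).cumsum(dim=0,
                                                         dtype=torch.int32)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])

print(query_start_loc)  # -> [0, 5, 8, 9, 10]
print(seq_start_loc)    # -> [0, 5, 8, 138, 396]
# max_seqlen_q = max(query_lens) = 5, max_seqlen_k = max(seq_lens) = 258;
# with block_table pointing into the paged KV cache, one kernel covers the
# whole batch, which is what lets the separate prefill/decode paths above be
# deleted.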
@@ -682,7 +681,6 @@ def _prepare_model_input_tensors( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, @@ -706,10 +704,7 @@ def _prepare_model_input_tensors( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -880,6 +875,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + query_start_loc = torch.arange(0, + max_batch_size + 2, + dtype=torch.int32, + device=self.device) + seq_start_loc = torch.arange(0, + max_batch_size + 2, + dtype=torch.int32, + device=self.device) # Prepare buffer for outputs. These will be reused for all batch sizes. # It will be filled after the first graph capture. @@ -947,7 +950,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], num_prefill_tokens=0, num_decode_tokens=batch_size, - max_prefill_seq_len=0, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor_host, paged_kv_indices=paged_kv_indices_tensor_host, @@ -971,13 +973,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, + seq_lens=seq_lens[:batch_size].tolist(), + max_query_len=1, + query_start_loc=query_start_loc[:batch_size + 1], + seq_start_loc=seq_start_loc[:batch_size + 1], context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, @@ -1122,9 +1121,7 @@ def execute_model( # Currently cuda graph is only supported by the decode phase. 
assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: + if model_input.attn_metadata.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] @@ -1160,7 +1157,7 @@ def execute_model( indices = model_input.sampling_metadata.selected_token_indices if model_input.is_prompt: hidden_states = hidden_states.index_select(0, indices) - elif decode_meta.use_cuda_graph: + elif model_input.attn_metadata.use_cuda_graph: hidden_states = hidden_states[:len(indices)] output.hidden_states = hidden_states @@ -1251,9 +1248,9 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": - attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, + "block_tables": attn_metadata.block_tables, + "seq_start_loc": attn_metadata.seq_start_loc, + "query_start_loc": attn_metadata.query_start_loc } self.output_buffers = {"hidden_states": hidden_states} return hidden_states @@ -1275,11 +1272,12 @@ def forward( self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) if self.backend_name != "flashinfer": - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, - non_blocking=True) self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + attn_metadata.block_tables, non_blocking=True) + self.input_buffers["query_start_loc"].copy_( + attn_metadata.query_start_loc, non_blocking=True) + self.input_buffers["seq_start_loc"].copy_( + attn_metadata.seq_start_loc, non_blocking=True) # Run the graph. self.graph.replay() From e449f00b08759dad2619cc6632bc374a809cd4e0 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 9 Jul 2024 03:57:41 +0000 Subject: [PATCH 02/12] fix --- vllm/attention/backends/flash_attn.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5d3a2c32af681..ef212779d4c9b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -94,9 +94,6 @@ class FlashAttentionMetadata(AttentionMetadata): # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) @@ -111,6 +108,13 @@ class FlashAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
use_cuda_graph: bool + # Fields that are not used in flash attention backend, + # but used in other backends + context_lens_tensor: Optional[torch.Tensor] = None + seq_lens_tensor: Optional[torch.Tensor] = None + max_prefill_seq_len: Optional[int] = None + max_decode_seq_len: Optional[int] = None + class FlashAttentionImpl(AttentionImpl): """ From 037c6343f05e2d576121f5be283401a4f6ebb4c1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 9 Jul 2024 04:11:50 +0000 Subject: [PATCH 03/12] fix xformer backend --- vllm/worker/model_runner.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 82e2efeb87850..6841eb79d7475 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -598,6 +598,7 @@ def _prepare_model_input_tensors( batch_size = len(input_tokens) max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) max_decode_seq_len = max(decode_seq_lens, default=0) # If cuda graph can be used, pad tensors accordingly. @@ -738,7 +739,10 @@ def _prepare_model_input_tensors( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -981,6 +985,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: dtype=torch.int32, device=self.device) seq_lens = [1] * max_batch_size + seq_lens_tensor = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the @@ -1050,7 +1057,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: num_decode_tokens=batch_size, slot_mapping=slot_mapping[:batch_size], seq_lens=seq_lens[:batch_size], - seq_lens_tensor=seq_lens[:batch_size], + seq_lens_tensor=seq_lens_tensor[:batch_size], max_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.max_seq_len_to_capture, @@ -1418,6 +1425,10 @@ def forward( attn_metadata.query_start_loc, non_blocking=True) self.input_buffers["seq_start_loc"].copy_( attn_metadata.seq_start_loc, non_blocking=True) + + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.seq_lens_tensor, non_blocking=True) + if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) From 66d33474f600bbbb9d19c3cb4792a7949f138c36 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 9 Jul 2024 16:32:20 +0000 Subject: [PATCH 04/12] fix cudagraph --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6841eb79d7475..5b0e0e998c026 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -616,6 +616,7 @@ def _prepare_model_input_tensors( input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) seq_lens.append(1) + query_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) From b1d2b5c5d734ab3bba209a0925832f9bf0b4d0f7 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Fri, 12 Jul 2024 15:33:17 -0700 Subject: [PATCH 05/12] fix ci --- tests/worker/test_model_runner.py | 23 +++++++++++++++++------ vllm/spec_decode/draft_model_runner.py | 4 ++-- 2 files changed, 19 insertions(+), 
8 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e1775790c0a03..0bfd4e2ac58bd 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -198,6 +198,11 @@ def test_prepare_decode_cuda_graph(batch_size): # decode has only 1 token for query. start_idx += 1 start_loc.append(start_idx) + # start_loc are padded to expected_bs + 1 + last_loc = start_loc[-1] + 1 + for _ in range(expected_bs - (len(start_loc) - 1)): + start_loc.append(last_loc) + last_loc += 1 assert torch.allclose( attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) @@ -207,6 +212,10 @@ def test_prepare_decode_cuda_graph(batch_size): for seq_len in seq_lens: start_idx += seq_len seq_start_loc.append(start_idx) + last_loc = seq_start_loc[-1] + 1 + for _ in range(expected_bs - (len(start_loc) - 1)): + start_loc.append(last_loc) + last_loc += 1 assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) @@ -373,9 +382,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): attn_metadata = model_runner._prepare_model_input_tensors( seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), - vars(prefill_meta_actual)): - assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), - vars(decode_meta_actual)): - assert attr_expected[1] == attr_actual[1] + if attn_metadata.prefill_metadata: + for attr_expected, attr_actual in zip( + vars(attn_metadata.prefill_metadata), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip( + vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 6a2cfc819d8d2..0479730955bfb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -142,8 +142,8 @@ def execute_model( # Currently cuda graph is only supported by the decode phase. 
assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: + if prefill_meta is None and \ + model_input.attn_metadata.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = ( From 2461c8a49592dae295b4ff87aacad269ed2351ae Mon Sep 17 00:00:00 2001 From: lilyliu Date: Fri, 12 Jul 2024 16:27:18 -0700 Subject: [PATCH 06/12] fix ci --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 15a0f8dac8e1b..38179350338a3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -754,6 +754,7 @@ def _prepare_model_input_tensors( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, From 1618fff7ddc3ed7b44bc4ee7aed75a80d4833f3a Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 17 Jul 2024 20:07:27 -0700 Subject: [PATCH 07/12] fix tests and style --- tests/worker/test_model_runner.py | 19 +++++++++++++------ vllm/attention/backends/flash_attn.py | 3 +-- vllm/worker/model_runner.py | 12 ------------ 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index b5742c4338616..cd7f614f5d040 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -199,6 +199,11 @@ def test_prepare_decode_cuda_graph(batch_size): # decode has only 1 token for query. start_idx += 1 start_loc.append(start_idx) + # start_loc are padded to expected_bs + 1 + last_loc = start_loc[-1] + 1 + for _ in range(expected_bs - (len(start_loc) - 1)): + start_loc.append(last_loc) + last_loc += 1 assert torch.allclose( attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) @@ -374,9 +379,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): attn_metadata = model_runner._prepare_model_input_tensors( seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), - vars(prefill_meta_actual)): - assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), - vars(decode_meta_actual)): - assert attr_expected[1] == attr_actual[1] + if attn_metadata.prefill_metadata: + for attr_expected, attr_actual in zip( + vars(attn_metadata.prefill_metadata), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip( + vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2cbcb78026b77..d8fd8b7e39177 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -421,8 +421,7 @@ def forward( self.kv_cache_dtype, ) - # This is used during the profiling phase - # or TODO + # This is used during the profiling or prefill phase. 
if kv_cache is None or (attn_metadata.block_tables is not None and attn_metadata.block_tables.numel()) == 0: k = key diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a391b349a05ad..b5b25fc40a70b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1404,15 +1404,3 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _is_block_tables_empty(block_tables: Union[None, Dict]): - """ - Check if block_tables is None or a dictionary with all None values. - """ - if block_tables is None: - return True - if isinstance(block_tables, dict) and all( - value is None for value in block_tables.values()): - return True - return False From 548b4d818497b077760ad895818dbb0a7c865667 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 23 Jul 2024 05:54:57 -0700 Subject: [PATCH 08/12] minor --- vllm/worker/model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 23e0908298acd..ce662fdf0f68f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -517,8 +517,8 @@ def build(self) -> ModelInputForGPU: device=self.runner.device) # Sequence and query lengths. - self.seq_lens.extend([1] * cuda_graph_pad_size) - self.query_lens.extend([1] * cuda_graph_pad_size) + seq_lens.extend([1] * cuda_graph_pad_size) + query_lens.extend([1] * cuda_graph_pad_size) # Attention metadata. attn_metadata = self.attn_metadata_builder.build( @@ -899,7 +899,7 @@ def profile_run(self) -> None: return def remove_all_loras(self): - if not selsf.lora_manager: + if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.remove_all_adapters() From 3c16563a2ffa1a9439980068b5f7bccfbd1135f6 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 24 Jul 2024 02:17:06 -0700 Subject: [PATCH 09/12] minor --- vllm/attention/backends/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 21eafe3acbf73..9bec0bf19753d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -134,7 +134,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: @property def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - return self + return None class FlashAttentionMetadataBuilder( From 0c0f6c81b51cbd742e3386325180fd31917d7226 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 25 Jul 2024 07:27:27 -0700 Subject: [PATCH 10/12] minor --- vllm/attention/backends/flash_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 9bec0bf19753d..6108690846974 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -246,6 +246,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], if block_table: input_block_tables[i, :len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device=device) + seq_lens.extend([1] * cuda_graph_pad_size) + query_lens.extend([1] * cuda_graph_pad_size) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -255,9 +257,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = 
torch.tensor(self.context_lens, - dtype=torch.int, - device=device) + context_lens_tensor = None seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) @@ -420,15 +420,18 @@ def forward( ) # This is used during the profiling or prefill phase. - if kv_cache is None or (attn_metadata.block_tables is not None - and attn_metadata.block_tables.numel()) == 0: + print("num_prefills", attn_metadata.num_prefills, attn_metadata.num_decode_tokens, kv_cache is None) + if kv_cache is None: + print("1-----------------") k = key v = value block_tables = None else: + print("2-----------------") k = kv_cache[0] v = kv_cache[1] block_tables = attn_metadata.block_tables + print(block_tables.shape, kv_cache is None, attn_metadata.slot_mapping.shape, key.shape, value.shape) max_seq_len = max(attn_metadata.seq_lens) output = flash_attn_varlen_func( From 7d970f9c28f3ce67e465e2a16da12c81eb0267a3 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 7 Aug 2024 10:26:22 -0700 Subject: [PATCH 11/12] add for example --- vllm/attention/backends/flash_attn.py | 72 +++++++++++++++++++-------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 7aa31834e646d..0b4e06fec290a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,6 +14,7 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.utils import make_tensor_with_pad +import debug_print if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -254,8 +255,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], if block_table: input_block_tables[i, :len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device=device) - seq_lens.extend([1] * cuda_graph_pad_size) - query_lens.extend([1] * cuda_graph_pad_size) + # print("block_tables", block_tables.shape, block_tables.data_ptr()) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -266,15 +266,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], assert max_query_len > 0, ("query_lens: {}".format(query_lens)) context_lens_tensor = None + query_start_loc = torch.tensor([0] + query_lens, + dtype=torch.int32, + device=device).cumsum(dim=0, + dtype=torch.int32) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + # print("seq_lens_tensor", seq_lens_tensor) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -282,10 +281,12 @@ def build(self, seq_lens: List[int], query_lens: List[int], dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + # print("seq_lens_tensor", seq_lens_tensor) + # print("query_start_loc", query_start_loc) + # print("seq_start_loc", seq_start_loc) + # print("slot_mapping", self.slot_mapping) + # print("max_seq_lens", max(seq_lens)) + # print("max_query_len", max_query_len) slot_mapping_tensor = torch.tensor(self.slot_mapping, dtype=torch.long, @@ -430,20 +431,46 @@ def forward( ) # This is used during the profiling or prefill phase. 
- print("num_prefills", attn_metadata.num_prefills, attn_metadata.num_decode_tokens, kv_cache is None) - if kv_cache is None: - print("1-----------------") + if kv_cache is None or (attn_metadata.block_tables is not None + and attn_metadata.block_tables.numel()) == 0: + # print("1-----------------") k = key v = value block_tables = None else: - print("2-----------------") + # print("2-----------------") k = kv_cache[0] v = kv_cache[1] block_tables = attn_metadata.block_tables - print(block_tables.shape, kv_cache is None, attn_metadata.slot_mapping.shape, key.shape, value.shape) max_seq_len = max(attn_metadata.seq_lens) + max_k = torch.max(k).reshape(1) + max_v = torch.max(v).reshape(1) + if attn_metadata.use_cuda_graph: + pass + # # block_tables.zero_() + # debug_print.print_tensor(query) + # debug_print.print_tensor(max_k) + # debug_print.print_tensor(max_v) + # debug_print.print_tensor(attn_metadata.query_start_loc) + # debug_print.print_tensor(attn_metadata.seq_start_loc) + # debug_print.print_tensor(attn_metadata.block_tables) + # debug_print.print_tensor(attn_metadata.seq_lens_tensor) + else: + pass + # print("query", query.shape, query[0, 0]) + # print("max_k", k.shape, max_k) + # print("max_v", v.shape, max_v) + # print("query_start_loc", attn_metadata.query_start_loc) + # print("seq_start_loc", attn_metadata.seq_start_loc) + # print("block_tables", block_tables) + # print("seq_lens", attn_metadata.seq_lens) + # print("max_query_len", attn_metadata.max_query_len) + # print("max_seqlen_k", max_seq_len) + # print("scale", self.scale) + # print("sliding_window", self.sliding_window) + # print("alibi_slopes", self.alibi_slopes) + output = flash_attn_varlen_func( q=query, k=k, @@ -452,11 +479,16 @@ def forward( cu_seqlens_k=attn_metadata.seq_start_loc, max_seqlen_q=attn_metadata.max_query_len, max_seqlen_k=max_seq_len, - softmax_scale=self.scale, + softmax_scale=0.125, causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, + window_size=(-1, -1), + alibi_slopes=None, block_table=block_tables) + # if attn_metadata.use_cuda_graph: + # pass + # # debug_print.print_tensor(output[0,0]) + # else: + # print(output[0,0]) # Reshape the output tensor. return output.view(num_tokens, hidden_size) From 66e832be41cd3f29bd2b37303ea5944efcb16204 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 7 Aug 2024 10:27:03 -0700 Subject: [PATCH 12/12] add for example --- tests/kernels/test_flash_attn.py | 147 ++++++++++++++++++++++++++----- 1 file changed, 125 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index cd06c27175cef..7046751203dc6 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,11 +4,11 @@ import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] +NUM_HEADS = [(12, 12)] +HEAD_SIZES = [64] +BLOCK_SIZES = [16] DTYPES = [torch.float16, torch.bfloat16] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +NUM_BLOCKS = 99682 # Large enough to test overflow in index calculation. 
def ref_paged_attn( @@ -123,23 +123,8 @@ def test_flash_attn_with_paged_kv( f"{torch.max(torch.abs(output - ref_output))}" -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - sliding_window: Optional[int], - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) +def prepare_varlen_with_paged_kv_input(seq_lens, num_heads, head_size, + sliding_window, dtype, block_size): num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -174,12 +159,43 @@ def test_varlen_with_paged_kv( dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + max_num_blocks_per_seq = 128 block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + output = (query_lens, kv_lens, query, key_cache, value_cache, + cu_query_lens, cu_kv_lens, max_query_len, max_kv_len, scale, + window_size, block_tables) + return output + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + + query_lens, kv_lens, query, key_cache, value_cache, \ + cu_query_lens, cu_kv_lens, max_query_len, \ + max_kv_len, scale, window_size, block_tables \ + = prepare_varlen_with_paged_kv_input(seq_lens, num_heads, + head_size, sliding_window, + dtype, block_size) + output = flash_attn_varlen_func( q=query, k=key_cache, @@ -206,3 +222,90 @@ def test_varlen_with_paged_kv( ) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 912)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv_cudagraph(seq_lens, num_heads, head_size, + sliding_window, dtype, block_size): + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + graph_seq_lens = [(1, 2)] + query_lens, kv_lens, g_query, g_key_cache, g_value_cache, \ + g_cu_query_lens, g_cu_kv_lens, max_query_len, \ + max_kv_len, scale, window_size, g_block_tables \ + = prepare_varlen_with_paged_kv_input(graph_seq_lens, num_heads, + head_size, sliding_window, + dtype, block_size) + + # Warmup + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + flash_attn_varlen_func( + q=g_query, + k=g_key_cache, + v=g_value_cache, + 
cu_seqlens_q=g_cu_query_lens, + cu_seqlens_k=g_cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=g_block_tables, + ) + torch.cuda.current_stream().wait_stream(s) + + # Capture + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + output = flash_attn_varlen_func( + q=g_query, + k=g_key_cache, + v=g_value_cache, + cu_seqlens_q=g_cu_query_lens, + cu_seqlens_k=g_cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=g_block_tables, + ) + torch.cuda.synchronize() + + # Replay + query_lens, kv_lens, query, key_cache, value_cache, \ + cu_query_lens, cu_kv_lens, max_query_len, \ + max_kv_len, scale, window_size, block_tables \ + = prepare_varlen_with_paged_kv_input(seq_lens, num_heads, + head_size, sliding_window, + dtype, block_size) + g_query.copy_(query) + g_key_cache.copy_(key_cache) + g_value_cache.copy_(value_cache) + g_cu_query_lens.copy_(cu_query_lens) + g_cu_kv_lens.copy_(cu_kv_lens) + g_block_tables.copy_(block_tables) + + graph.replay() + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}"
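The new test above follows the standard PyTorch capture/replay recipe: warm up on a side stream, capture under torch.cuda.graph, then overwrite the captured input buffers in place with copy_() and call replay(). A minimal, self-contained sketch of that recipe (using a plain linear layer instead of flash_attn_varlen_func, so it runs without the vllm_flash_attn package; shapes are arbitrary) is:

import torch

model = torch.nn.Linear(64, 64, device="cuda", dtype=torch.float16)
static_x = torch.zeros(8, 64, device="cuda", dtype=torch.float16)

# Warm up on a side stream so one-time kernel/library initialization is not
# recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture: tensors touched inside this block become fixed addresses that
# every later replay reads from and writes to.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out = model(static_x)

# Replay with fresh data: copy into the captured input buffer, never
# reassign it to a new tensor.
static_x.copy_(torch.randn_like(static_x))
graph.replay()
torch.cuda.synchronize()
print(static_out.shape)  # torch.Size([8, 64])

The same fixed-address constraint is why the CUDAGraphRunner changes in model_runner.py copy block_tables, query_start_loc, and seq_start_loc into pre-allocated input_buffers before graph.replay(), mirroring the copy_() calls in the test.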