From 92409f0589b845e8ae011ac61d62083da631c2f5 Mon Sep 17 00:00:00 2001 From: William Lin Date: Fri, 16 Aug 2024 11:41:56 -0700 Subject: [PATCH] [spec decode] [4/N] Move update_flash_attn_metadata to attn backend (#7571) Co-authored-by: Cody Yu --- vllm/attention/backends/abstract.py | 3 ++ vllm/attention/backends/flash_attn.py | 45 ++++++++++++++++++++++++++ vllm/spec_decode/draft_model_runner.py | 34 +------------------ 3 files changed, 49 insertions(+), 33 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 97b13917ccfaa..23c7830cd6264 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -75,6 +75,9 @@ def copy_blocks( ) -> None: raise NotImplementedError + def advance_step(self, num_seqs: int, num_queries: int): + raise NotImplementedError + @dataclass class AttentionMetadata: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f230bb57e3177..f146285bfc9e2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -297,6 +297,51 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata + def advance_step(self, num_seqs: int, num_queries: int): + """ + Update metadata in-place to advance one decode step. + """ + # GPU in-place update is currently called separately through + # custom_ops.advance_step(). See draft_model_runner. TODO(will): Move + # this logic to the backend. + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 7707d38a0f666..324044c96d994 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -97,38 +97,6 @@ def __init__( self.flashinfer_prefill_workspace_buffer = None self.flashinfer_prefill_wrapper = None - def _update_flash_attn_metadata(self, attn_metadata, num_seqs, - num_queries): - assert isinstance(attn_metadata, FlashAttentionMetadata) - - if num_seqs != num_queries: - assert num_seqs > num_queries - assert attn_metadata.use_cuda_graph - - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_prefill_tokens == 0 - assert attn_metadata.num_decode_tokens == num_seqs - assert attn_metadata.slot_mapping.shape == (num_seqs, ) - - assert len(attn_metadata.seq_lens) == num_seqs - assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) - assert attn_metadata.max_query_len == 1 - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) - - assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) - assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) - - assert attn_metadata.context_lens_tensor.shape == (num_queries, ) - - assert attn_metadata.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - attn_metadata.seq_lens[i] += 1 - attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) - def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): @@ -166,7 +134,7 @@ def _gpu_advance_step( # Update attn_metadata attn_metadata = model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + attn_metadata.advance_step(num_seqs, num_queries) # Update GPU tensors ops.advance_step(num_seqs=num_seqs,