From 49ae30cbe4518e41da73dbf91204f40aded28b3b Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:18:14 -0500 Subject: [PATCH] [Spec Decode] Move ops.advance_step to flash attn advance_step (#8224) Signed-off-by: Amit Garg --- vllm/attention/backends/flash_attn.py | 21 +++++++++++++++------ vllm/spec_decode/draft_model_runner.py | 16 +++------------- vllm/worker/multi_step_model_runner.py | 19 +++++-------------- 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index a4fcf3644fd0f..e171ed2eea43e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -16,7 +16,8 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache @@ -306,14 +307,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, num_seqs: int, num_queries: int): + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int): """ Update metadata in-place to advance one decode step. """ - # GPU in-place update is currently called separately through - # custom_ops.advance_step(). See draft_model_runner. TODO(will): Move - # this logic to the backend. - # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries @@ -351,6 +350,16 @@ def advance_step(self, num_seqs: int, num_queries: int): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 6e35e40294381..1e403637d2388 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,7 +2,6 @@ import torch -from vllm import _custom_ops as ops from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -116,18 +115,9 @@ def _gpu_advance_step( # Update attn_metadata attn_metadata = model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - attn_metadata.advance_step(num_seqs, num_queries) - - # Update GPU tensors - ops.advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=self.block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=attn_metadata.seq_lens_tensor, - slot_mapping=attn_metadata.slot_mapping, - block_tables=attn_metadata.block_tables) + + attn_metadata.advance_step(model_input, sampled_token_ids, + self.block_size, num_seqs, num_queries) # Update sampling_metadata sampling_metadata = model_input.sampling_metadata diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b13cf39bd846e..9a196c3dfcd1f 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -13,7 +13,6 @@ import torch -from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, @@ -499,19 +498,11 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - attn_metadata.advance_step(num_seqs, num_queries) - - # Update GPU tensors - ops.advance_step( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=self.block_size, - input_tokens=frozen_model_input.input_tokens, - sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, - input_positions=frozen_model_input.input_positions, - seq_lens=attn_metadata.seq_lens_tensor, - slot_mapping=attn_metadata.slot_mapping, - block_tables=attn_metadata.block_tables) + + attn_metadata.advance_step( + frozen_model_input, + model_input.cached_outputs[-1].sampled_token_ids, self.block_size, + num_seqs, num_queries) if frozen_model_input.seq_lens is not None: for i in range(num_queries):