Skip to content

Commit

Permalink
[Spec Decode] Move ops.advance_step to flash attn advance_step (vllm-…
Browse files Browse the repository at this point in the history
…project#8224)

Signed-off-by: Amit Garg <[email protected]>
  • Loading branch information
kevin314 authored and garg-amit committed Oct 28, 2024
1 parent ce22148 commit 49ae30c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 33 deletions.
21 changes: 15 additions & 6 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
Expand Down Expand Up @@ -306,14 +307,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, num_seqs: int, num_queries: int):
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.

# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
Expand Down Expand Up @@ -351,6 +350,16 @@ def advance_step(self, num_seqs: int, num_queries: int):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
Expand Down
16 changes: 3 additions & 13 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput

try:
Expand Down Expand Up @@ -116,18 +115,9 @@ def _gpu_advance_step(
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)

# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)

attn_metadata.advance_step(model_input, sampled_token_ids,
self.block_size, num_seqs, num_queries)

# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
Expand Down
19 changes: 5 additions & 14 deletions vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import torch

from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
Expand Down Expand Up @@ -499,19 +498,11 @@ def _advance_step(self, model_input: StatefulModelInput,

attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)

# Update GPU tensors
ops.advance_step(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)

attn_metadata.advance_step(
frozen_model_input,
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs, num_queries)

if frozen_model_input.seq_lens is not None:
for i in range(num_queries):
Expand Down

0 comments on commit 49ae30c

Please sign in to comment.