Skip to content

Commit

Permalink
[Core] Refactor _prepare_model_input_tensors - take 2 (vllm-project#6164
Browse files Browse the repository at this point in the history
)
  • Loading branch information
comaniac authored and gnpinkert committed Jul 26, 2024
1 parent 0360ae3 commit 724e4f4
Show file tree
Hide file tree
Showing 12 changed files with 1,050 additions and 470 deletions.
6 changes: 5 additions & 1 deletion tests/worker/test_model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from vllm.attention import AttentionMetadata
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
Expand All @@ -26,6 +26,10 @@ def get_impl_cls():
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata

@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise AttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
Expand Down
4 changes: 3 additions & 1 deletion vllm/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
]
45 changes: 43 additions & 2 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase


class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
Expand Down Expand Up @@ -35,6 +39,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError

@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_kv_cache_shape(
Expand Down Expand Up @@ -110,6 +124,33 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self, input_builder) -> None:
raise NotImplementedError

@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError

@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
Expand Down
11 changes: 11 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
Expand Down Expand Up @@ -93,6 +94,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata

@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
Expand Down Expand Up @@ -244,6 +249,12 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
return self._cached_decode_metadata


class BlocksparseFlashAttentionMetadataBuilder(
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):

_metadata_cls = BlocksparseFlashAttentionMetadata


class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
Expand Down
183 changes: 181 additions & 2 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)


class FlashAttentionBackend(AttentionBackend):
Expand All @@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
Expand Down Expand Up @@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
return self._cached_decode_metadata


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)

if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)

# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")

max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)


class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
Expand Down
Loading

0 comments on commit 724e4f4

Please sign in to comment.