Skip to content

Commit

Permalink
[Core] Refactor Attention Take 2 (#3462)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Mar 25, 2024
1 parent b0dfa91 commit 925f333
Show file tree
Hide file tree
Showing 47 changed files with 1,269 additions and 1,118 deletions.
3 changes: 1 addition & 2 deletions tests/kernels/test_prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import time

import torch
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

Expand Down
7 changes: 7 additions & 0 deletions tests/samplers/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import gc

import pytest
import torch

# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
Expand Down Expand Up @@ -36,6 +39,10 @@ def test_beam_search_single_input(
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
# NOTE(woosuk): For some reason, the following GC is required to avoid
# GPU OOM errors in the following tests using `vllm_runner`.
gc.collect()
torch.cuda.empty_cache()

for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
Expand Down
60 changes: 30 additions & 30 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens

# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)
assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens)

# Test subquery start locs.
start_idx = 0
Expand All @@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size):
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
attn_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

# Test seq start locs. Note that for normal prefill it is
Expand All @@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size):
seq_start_loc.append(start_idx)

assert torch.allclose(
input_metadata.seq_start_loc,
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert attn_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
attn_metadata.context_lens,
torch.zeros(attn_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))

expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"
assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
Expand Down Expand Up @@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size):
block_tables={0: [1]},
))

input_tokens, input_positions, input_metadata, _, _, _ = (
input_tokens, input_positions, attn_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == expected_bs
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
attn_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))

# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert input_metadata.block_tables.shape[1] == (
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, )
Expand Down
10 changes: 10 additions & 0 deletions vllm/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
]
File renamed without changes.
85 changes: 85 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type

import torch


class AttentionBackend(ABC):
"""Abstract class for attention backends."""

@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError

@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError

@staticmethod
@abstractmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError

@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError

@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError


@dataclass
class AttentionMetadata:

def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
}


class AttentionImpl(ABC):

@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
raise NotImplementedError

@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
raise NotImplementedError
Loading

0 comments on commit 925f333

Please sign in to comment.