prefill only attention
noooop committed Sep 30, 2024
1 parent 7b26033 commit 250be85
Showing 21 changed files with 1,326 additions and 0 deletions.
77 changes: 77 additions & 0 deletions tests/wde/prefill_only/layers/attention/test_basic_correctness.py
@@ -0,0 +1,77 @@
import itertools as it

import pytest
import torch

from tests.wde.utils import compare_embeddings
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.wde.core.layers.attention import Attention
from vllm.wde.prefill_only.layers.attention.selector import (AttentionImpls,
AttentionType,
AttnBackend,
_Backend)

SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS))))
def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int,
attn_type: str, dtype: str, n_seqs: int):
assert num_heads % num_kv_heads == 0

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

attention_impls = AttentionImpls[dtype]

seq_lens = SEQ_LENS[:n_seqs]
batchsize = sum(seq_lens)

query = torch.rand((batchsize, num_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
key = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
value = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))

impl_outputs_list = []

for attention_impl in attention_impls:
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
scaling = head_dim**-0.5

attn = Attention(num_heads,
head_dim,
scale=scaling,
num_kv_heads=num_kv_heads,
attn_backend=attn_backend)

metadata_builder = attn_backend.make_metadata_builder()
attn_metadata = metadata_builder(seq_lens=seq_lens)
attn_metadata = attn_metadata.to("cuda:0")

outputs = attn.forward(query, key, value, attn_metadata)

impl_outputs_list.append((attention_impl, outputs))

tolerance = 1e-2
for a, b in it.combinations(impl_outputs_list, 2):
similarities = compare_embeddings(a[1], b[1])
all_similarities = torch.stack(similarities)

assert torch.all(
(all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0"
56 changes: 56 additions & 0 deletions tests/wde/prefill_only/layers/attention/test_enum_verify.py
@@ -0,0 +1,56 @@
import pytest

from vllm.wde.prefill_only.layers.attention.backends.abstract import (
PrefillOnlyAttentionBackend)
from vllm.wde.prefill_only.layers.attention.selector import (AttentionImpls,
AttentionType,
AttnBackend,
_Backend)


def get_attn_backend(attention_impl: str, attn_type: str):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
return attn_backend


@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_backend(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
attn_backend = get_attn_backend(attention_impl, attn_type)

assert isinstance(attn_backend, PrefillOnlyAttentionBackend)


@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
with pytest.raises(NotImplementedError):
get_attn_backend(attention_impl, attn_type)


def test_not_supported_backend():
attention_impls = ["not_supported_backend", 0, 1.0]

for attention_impl in attention_impls:
with pytest.raises(ValueError):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
AttnBackend.get_backend_cls(selected_backend)


def test_not_supported_attn_type():
attn_types = ["not_supported_attn_type", 0, 1.0]

for attn_type in attn_types:
with pytest.raises(ValueError):
AttentionType.attn_type_name_to_enum(attn_type)
9 changes: 9 additions & 0 deletions tests/wde/utils.py
@@ -0,0 +1,9 @@
import torch.nn.functional as F


def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(e1, e2, dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities
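
For reference, a quick usage sketch of this helper with made-up inputs: it returns one cosine similarity per paired row, which test_basic_correctness then stacks and checks against a tolerance around 1.0.

# Illustrative only; the tensors below are random stand-ins for model outputs.
import torch

from tests.wde.utils import compare_embeddings

a = torch.randn(4, 128)
b = a + 1e-3 * torch.randn(4, 128)  # slightly perturbed copy of a

similarities = compare_embeddings(a, b)  # list of 4 scalar tensors
print(torch.stack(similarities))  # all values close to 1.0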
8 changes: 8 additions & 0 deletions vllm/wde/core/layers/attention/__init__.py
@@ -0,0 +1,8 @@
from vllm.wde.core.layers.attention.abstract import (AttentionBackend,
AttentionMetadata,
AttentionType)
from vllm.wde.core.layers.attention.layer import Attention

__all__ = [
"Attention", "AttentionMetadata", "AttentionBackend", "AttentionType"
]
124 changes: 124 additions & 0 deletions vllm/wde/core/layers/attention/abstract.py
@@ -0,0 +1,124 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

import torch


class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V

@staticmethod
def attn_type_name_to_enum(attn_type: str) -> "AttentionType":
assert attn_type is not None

attn_type_members = AttentionType.__members__
if attn_type not in attn_type_members:
            raise ValueError(
                f"Invalid attn_type '{attn_type}'. "
                f"Available attn_types: {', '.join(attn_type_members)} "
                "(case-sensitive).")

return AttentionType[attn_type]


class AttentionBackend(ABC):
"""Abstract class for attention backends."""

def __init__(self, attn_type: AttentionType):
self._attn_type = attn_type

@property
def attn_type(self) -> AttentionType:
return self._attn_type

@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError

@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError

@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError

@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)


@dataclass
class AttentionMetadata:
pass

def to(self, device, non_blocking=False):
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.to(device, non_blocking=non_blocking)

return self


T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self) -> None:
raise NotImplementedError

@abstractmethod
def __call__(self, *args, **kwargs) -> T:
raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
raise NotImplementedError

@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_cache: Optional[torch.Tensor] = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.ENCODER,
) -> torch.Tensor:
raise NotImplementedError
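
To make the contract above concrete, here is a minimal, hypothetical sketch of a backend built directly on these abstractions: a metadata dataclass carrying per-sequence lengths, a builder, a torch SDPA implementation that loops over the packed sequences without any KV cache (prefill only), and a backend class wiring the pieces together. Every Naive* name is an illustrative assumption and not part of this commit.

# Hypothetical sketch only -- none of the Naive* classes exist in this commit.
from dataclasses import dataclass
from typing import List, Type

import torch

from vllm.wde.core.layers.attention.abstract import (AttentionBackend,
                                                     AttentionImpl,
                                                     AttentionMetadata,
                                                     AttentionMetadataBuilder,
                                                     AttentionType)


@dataclass
class NaiveMetadata(AttentionMetadata):
    # Lengths of the sequences packed along the flattened token dimension.
    seq_lens: List[int]


class NaiveMetadataBuilder(AttentionMetadataBuilder[NaiveMetadata]):

    def __init__(self) -> None:
        pass

    def __call__(self, seq_lens: List[int]) -> NaiveMetadata:
        return NaiveMetadata(seq_lens=seq_lens)


class NaiveImpl(AttentionImpl[NaiveMetadata]):

    def __init__(self, num_heads, head_size, scale, num_kv_heads=None,
                 alibi_slopes=None, sliding_window=None,
                 kv_cache_dtype="auto", blocksparse_params=None,
                 logits_soft_cap=None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads or num_heads

    def forward(self, query, key, value, attn_metadata, kv_cache=None,
                k_scale=1.0, v_scale=1.0,
                attn_type=AttentionType.ENCODER) -> torch.Tensor:
        # Inputs arrive flattened as (num_tokens, num_(kv_)heads * head_size).
        n = query.shape[0]
        q = query.view(n, self.num_heads, self.head_size)
        k = key.view(n, self.num_kv_heads, self.head_size)
        v = value.view(n, self.num_kv_heads, self.head_size)
        n_rep = self.num_heads // self.num_kv_heads
        outputs, start = [], 0
        for seq_len in attn_metadata.seq_lens:
            end = start + seq_len
            # (heads, seq_len, head_size); repeat KV heads for GQA/MQA.
            q_s = q[start:end].transpose(0, 1)
            k_s = k[start:end].transpose(0, 1).repeat_interleave(n_rep, dim=0)
            v_s = v[start:end].transpose(0, 1).repeat_interleave(n_rep, dim=0)
            out = torch.nn.functional.scaled_dot_product_attention(
                q_s, k_s, v_s, scale=self.scale,
                is_causal=(attn_type == AttentionType.DECODER))
            outputs.append(out.transpose(0, 1))
            start = end
        return torch.cat(outputs).reshape(n, -1)


class NaiveBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "naive"

    @staticmethod
    def get_impl_cls() -> Type[NaiveImpl]:
        return NaiveImpl

    @staticmethod
    def get_metadata_cls() -> Type[NaiveMetadata]:
        return NaiveMetadata

    @staticmethod
    def get_builder_cls() -> Type[NaiveMetadataBuilder]:
        return NaiveMetadataBuilder

The real backends in this commit subclass PrefillOnlyAttentionBackend (see test_enum_verify.py above); the sketch only illustrates how the get_* hooks and the flattened (num_tokens, num_heads * head_size) tensor layout fit together.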
101 changes: 101 additions & 0 deletions vllm/wde/core/layers/attention/layer.py
@@ -0,0 +1,101 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.wde.core.layers.attention.abstract import AttentionBackend


class Attention(nn.Module):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
attn_backend: AttentionBackend,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "",
) -> None:
super().__init__()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
# block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
else:
kv_cache_dtype = "auto"
# block_size = 16
sliding_window = None
if num_kv_heads is None:
num_kv_heads = num_heads

# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)

impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
self.attn_type = attn_backend.attn_type

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:

return self.impl.forward(query, key, value, attn_metadata, kv_cache,
self._k_scale, self._v_scale, self.attn_type)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
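
For orientation, a hedged usage sketch of this layer with a backend obtained through the prefill-only selector, mirroring the pattern in test_basic_correctness.py above. The backend name "TORCH_SDPA" is an assumption here; the valid names are whatever AttentionImpls exposes for the chosen dtype.

# Usage sketch; the backend name "TORCH_SDPA" is assumed, not taken from
# this commit.
import torch

from vllm.wde.core.layers.attention import Attention
from vllm.wde.prefill_only.layers.attention.selector import (AttentionType,
                                                             AttnBackend,
                                                             _Backend)

backend_cls = AttnBackend.get_backend_cls(
    _Backend.backend_name_to_enum("TORCH_SDPA"))
attn_backend = backend_cls(AttentionType.attn_type_name_to_enum("ENCODER"))

num_heads, head_dim = 8, 64
attn = Attention(num_heads,
                 head_dim,
                 scale=head_dim**-0.5,
                 num_kv_heads=num_heads,
                 attn_backend=attn_backend)

seq_lens = [3, 5]
attn_metadata = attn_backend.make_metadata_builder()(seq_lens=seq_lens)
attn_metadata = attn_metadata.to("cuda:0")

num_tokens = sum(seq_lens)
query = torch.rand(num_tokens, num_heads * head_dim,
                   dtype=torch.half, device="cuda:0")
out = attn(query, query, query, attn_metadata)  # (num_tokens, num_heads * head_dim)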
1 change: 1 addition & 0 deletions vllm/wde/prefill_only/layers/attention/__init__.py
@@ -0,0 +1 @@
