-
-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Refactor Attention Take 2 (#3462)
- Loading branch information
1 parent b0dfa91, commit 925f333
Showing 47 changed files with 1,269 additions and 1,118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata | ||
from vllm.attention.layer import Attention | ||
from vllm.attention.selector import get_attn_backend | ||
|
||
# Names re-exported as the public surface of this package; each one is
# imported at the top of this module.
__all__ = [
    "AttentionBackend",
    "AttentionMetadata",
    "Attention",
    "get_attn_backend",
]
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, fields | ||
from typing import Any, Dict, List, Optional, Tuple, Type | ||
|
||
import torch | ||
|
||
|
||
class AttentionBackend(ABC):
    """Abstract interface that every attention backend has to provide.

    All hooks are declared as abstract static methods. The bodies here only
    raise ``NotImplementedError``; concrete backends supply the behavior.
    """

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass belonging to this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Build an ``AttentionMetadata`` object from the given arguments.

        The accepted arguments are left to the concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape, as a tuple of ints, for these sizes.

        NOTE(review): the resulting layout is backend-specific and is not
        defined by this interface.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Swap blocks between two KV-cache tensors; returns nothing.

        NOTE(review): ``src_to_dst`` presumably maps a source block number
        to its destination block number -- confirm against a concrete
        backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy blocks within the given KV-cache tensors; returns nothing.

        NOTE(review): ``src_to_dists`` presumably maps one source block
        number to the list of destination block numbers -- confirm against
        a concrete backend.
        """
        raise NotImplementedError
|
||
|
||
@dataclass
class AttentionMetadata:
    """Dataclass base for attention metadata; declares no fields itself."""

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Return a ``{field name: value}`` dict without copying the values.

        Unlike ``dataclasses.asdict``, values are not deep-copied: every
        entry is the same object that is stored on the instance.
        """
        # Nested dataclass fields are NOT expanded here; if such fields are
        # ever added they will need the same treatment.
        shallow: Dict[str, Any] = {}
        for spec in fields(self):
            shallow[spec.name] = getattr(self, spec.name)
        return shallow
|
||
|
||
class AttentionImpl(ABC):
    """Abstract base for an attention implementation.

    Both the constructor and ``forward`` are abstract, so subclasses must
    define each of them.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Declare the constructor signature shared by implementations.

        ``num_kv_heads``, ``alibi_slopes`` and ``sliding_window`` are
        optional and default to ``None``; how ``None`` is interpreted is up
        to the concrete implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention on ``query``/``key``/``value`` and return a tensor.

        NOTE(review): tensor shapes and the way ``kv_cache`` and
        ``attn_metadata`` are used are not fixed by this interface --
        confirm against a concrete implementation.
        """
        raise NotImplementedError
Oops, something went wrong.