This repository has been archived by the owner on Oct 11, 2024. It is now read-only.
forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (vllm-…
…project#3643) Co-authored-by: jpvillam <[email protected]> Co-authored-by: Gregory Shtrasberg <[email protected]> Co-authored-by: Woosuk Kwon <[email protected]>
- Loading branch information
1 parent
ddf95b9
commit b2134e5
Showing
5 changed files
with
1,212 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,348 @@ | ||
"""Attention layer ROCm GPUs.""" | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Optional, Tuple, Type | ||
|
||
import torch | ||
|
||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||
AttentionMetadata) | ||
from vllm.attention.ops.paged_attn import (PagedAttention, | ||
PagedAttentionMetadata) | ||
from vllm.logger import init_logger | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class ROCmFlashAttentionBackend(AttentionBackend): | ||
|
||
@staticmethod | ||
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: | ||
return ROCmFlashAttentionImpl | ||
|
||
@staticmethod | ||
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": | ||
return ROCmFlashAttentionMetadata(*args, **kwargs) | ||
|
||
@staticmethod | ||
def get_kv_cache_shape( | ||
num_blocks: int, | ||
block_size: int, | ||
num_kv_heads: int, | ||
head_size: int, | ||
) -> Tuple[int, ...]: | ||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size, | ||
num_kv_heads, head_size) | ||
|
||
@staticmethod | ||
def swap_blocks( | ||
src_kv_cache: torch.Tensor, | ||
dst_kv_cache: torch.Tensor, | ||
src_to_dst: Dict[int, int], | ||
) -> None: | ||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) | ||
|
||
@staticmethod | ||
def copy_blocks( | ||
kv_caches: List[torch.Tensor], | ||
src_to_dists: Dict[int, List[int]], | ||
) -> None: | ||
PagedAttention.copy_blocks(kv_caches, src_to_dists) | ||
|
||
|
||
@dataclass | ||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): | ||
"""Metadata for FlashAttentionBackend. | ||
NOTE: Any python object stored here is not updated when it is | ||
cuda-graph replayed. If you have values that need to be changed | ||
dynamically, it should be stored in tensor. The tensor has to be | ||
updated from `CUDAGraphRunner.forward` API. | ||
""" | ||
# Currently, input sequences can only contain all prompts | ||
# or all decoding. True if all sequences are prompts. | ||
is_prompt: bool | ||
# (batch_size,). The prompt length per sequence. None if it is a decoding. | ||
prompt_lens: Optional[List[int]] | ||
# prompt_lens stored as a tensor. | ||
prompt_lens_tensor: Optional[torch.Tensor] | ||
# The number of prompt tokens. Doesn't include padding. | ||
num_prompt_tokens: int | ||
# The number of generation tokens. Doesn't include padding. | ||
num_generation_tokens: int | ||
|
||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. | ||
# |---------- N-1 iteration --------| | ||
# |---------------- N iteration ---------------------| | ||
# |- tokenA -|......................|-- newTokens ---| | ||
# |---------- context_len ----------| | ||
# |-------------------- seqlen ----------------------| | ||
# |- subquery_len -| | ||
|
||
# WARNING(sang): context_len has different definition depending on if it is | ||
# prefill vs decoding. When it is prefill, it doesn't include new tokens. | ||
# When it is for decoding, it includes a new token. | ||
|
||
# Maximum subquery length in the batch. | ||
max_subquery_len: Optional[int] | ||
# Maximum prompt length in the batch. | ||
max_prompt_len: Optional[int] | ||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in | ||
# the batch, used to index into subquery. E.g., if the subquery length | ||
# is [4, 6], it is [0, 4, 10]. | ||
subquery_start_loc: Optional[torch.Tensor] | ||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in | ||
# the batch, used to index into sequence. E.g., if the sequence length is | ||
# [4, 6], it is [0, 4, 10]. | ||
seq_start_loc: Optional[torch.Tensor] | ||
|
||
# Whether or not if cuda graph is enabled. | ||
# Cuda-graph is currently enabled for decoding only. | ||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. | ||
use_cuda_graph: bool | ||
|
||
|
||
class ROCmFlashAttentionImpl(AttentionImpl): | ||
""" | ||
If the input tensors contain prompt tokens, the layout is as follows: | ||
|<--------------- num_prompt_tokens -------------->| | ||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| | ||
Otherwise, the layout is as follows: | ||
|<------------------ num_generation_tokens (M) ----------------->| | ||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| | ||
Generation tokens can contain padding when cuda-graph is used. | ||
Currently, prompt tokens don't contain any padding. | ||
The prompts might have different lengths, while the generation tokens | ||
always have length 1. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_heads: int, | ||
head_size: int, | ||
scale: float, | ||
num_kv_heads: Optional[int] = None, | ||
alibi_slopes: Optional[List[float]] = None, | ||
sliding_window: Optional[int] = None, | ||
) -> None: | ||
self.num_heads = num_heads | ||
self.head_size = head_size | ||
self.scale = float(scale) | ||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads | ||
self.sliding_window = ((sliding_window, sliding_window) | ||
if sliding_window is not None else (-1, -1)) | ||
if alibi_slopes is not None: | ||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) | ||
self.alibi_slopes = alibi_slopes | ||
|
||
assert self.num_heads % self.num_kv_heads == 0 | ||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads | ||
|
||
suppored_head_sizes = PagedAttention.get_supported_head_sizes() | ||
if head_size not in suppored_head_sizes: | ||
raise ValueError( | ||
f"Head size {head_size} is not supported by PagedAttention. " | ||
f"Supported head sizes are: {suppored_head_sizes}.") | ||
|
||
self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 | ||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. | ||
self.use_triton_flash_attn = (os.environ.get( | ||
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) | ||
if self.use_naive_attn: | ||
# AMD Radeon 7900 series (gfx1100) currently does not support | ||
# xFormers nor FlashAttention. As a temporary workaround, we use | ||
# naive PyTorch implementation of attention. | ||
self.attn_fuc = _naive_attention() | ||
logger.debug("Using naive attention in ROCmBackend") | ||
elif self.use_triton_flash_attn: | ||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 | ||
triton_attention) | ||
self.attn_func = triton_attention | ||
logger.debug("Using Triton FA in ROCmBackend") | ||
else: | ||
from flash_attn import flash_attn_varlen_func # noqa: F401 | ||
self.attn_func = flash_attn_varlen_func | ||
logger.debug("Using CK FA in ROCmBackend") | ||
|
||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)""" | ||
tokens, n_kv_heads, head_dim = x.shape | ||
return (x[:, :, | ||
None, :].expand(tokens, n_kv_heads, n_rep, | ||
head_dim).reshape(tokens, n_kv_heads * n_rep, | ||
head_dim)) | ||
|
||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
kv_cache: torch.Tensor, | ||
attn_metadata: ROCmFlashAttentionMetadata, | ||
kv_scale: float = 1.0, | ||
) -> torch.Tensor: | ||
"""Forward pass with FlashAttention and PagedAttention. | ||
Args: | ||
query: shape = [num_tokens, num_heads * head_size] | ||
key: shape = [num_tokens, num_kv_heads * head_size] | ||
value: shape = [num_tokens, num_kv_heads * head_size] | ||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] | ||
attn_metadata: Metadata for attention. | ||
Returns: | ||
shape = [num_tokens, num_heads * head_size] | ||
""" | ||
num_tokens, hidden_size = query.shape | ||
# Reshape the query, key, and value tensors. | ||
query = query.view(-1, self.num_heads, self.head_size) | ||
key = key.view(-1, self.num_kv_heads, self.head_size) | ||
value = value.view(-1, self.num_kv_heads, self.head_size) | ||
|
||
if kv_cache is not None: | ||
key_cache, value_cache = PagedAttention.split_kv_cache( | ||
kv_cache, self.num_kv_heads, self.head_size) | ||
|
||
# Reshape the input keys and values and store them in the cache. | ||
# If kv_cache is not provided, the new key and value tensors are | ||
# not cached. This happens during the initial memory profiling run. | ||
PagedAttention.write_to_paged_cache( | ||
key, | ||
value, | ||
key_cache, | ||
value_cache, | ||
attn_metadata.slot_mapping, | ||
attn_metadata.kv_cache_dtype, | ||
kv_scale, | ||
) | ||
|
||
if attn_metadata.is_prompt: | ||
# Prompt run. | ||
if kv_cache is None or attn_metadata.block_tables.numel() == 0: | ||
# triton attention | ||
# When block_tables are not filled, it means q and k are the | ||
# prompt, and they have the same length. | ||
if self.use_naive_attn or self.use_triton_flash_attn: | ||
if self.num_kv_heads != self.num_heads: | ||
# Interleave for MQA workaround. | ||
key = self.repeat_kv(key, self.num_queries_per_kv) | ||
value = self.repeat_kv(value, self.num_queries_per_kv) | ||
if self.use_naive_attn: | ||
output = self.attn_fuc( | ||
query, | ||
key, | ||
value, | ||
attn_metadata.prompt_lens, | ||
self.scale, | ||
) | ||
else: | ||
output, _ = self.attn_func( | ||
query, | ||
key, | ||
value, | ||
None, | ||
attn_metadata.seq_start_loc, | ||
attn_metadata.seq_start_loc, | ||
attn_metadata.max_prompt_len, | ||
attn_metadata.max_prompt_len, | ||
True, | ||
self.scale, | ||
) | ||
else: | ||
output = self.attn_func( | ||
q=query, | ||
k=key, | ||
v=value, | ||
cu_seqlens_q=attn_metadata.seq_start_loc, | ||
cu_seqlens_k=attn_metadata.seq_start_loc, | ||
max_seqlen_q=attn_metadata.max_prompt_len, | ||
max_seqlen_k=attn_metadata.max_prompt_len, | ||
softmax_scale=self.scale, | ||
causal=True, | ||
) | ||
|
||
else: | ||
# prefix-enabled attention | ||
output = PagedAttention.forward_prefix( | ||
query, | ||
key, | ||
value, | ||
key_cache, | ||
value_cache, | ||
attn_metadata.block_tables, | ||
attn_metadata.subquery_start_loc, | ||
attn_metadata.prompt_lens_tensor, | ||
attn_metadata.context_lens, | ||
attn_metadata.max_subquery_len, | ||
self.alibi_slopes, | ||
) | ||
else: | ||
# Decoding run. | ||
output = PagedAttention.forward_decode( | ||
query, | ||
key_cache, | ||
value_cache, | ||
attn_metadata.block_tables, | ||
attn_metadata.context_lens, | ||
attn_metadata.max_context_len, | ||
attn_metadata.kv_cache_dtype, | ||
self.num_kv_heads, | ||
self.scale, | ||
self.alibi_slopes, | ||
kv_scale, | ||
) | ||
|
||
# Reshape the output tensor. | ||
return output.view(num_tokens, hidden_size) | ||
|
||
|
||
def _naive_attention( | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
prompt_lens: List[int], | ||
scale: float, | ||
) -> torch.Tensor: | ||
num_tokens = query.shape[0] | ||
output = torch.empty_like(query) | ||
start = 0 | ||
for _, prompt_len in enumerate(prompt_lens): | ||
end = start + prompt_len | ||
out = _naive_masked_attention( | ||
query[None, start:end], | ||
key[None, start:end], | ||
value[None, start:end], | ||
scale, | ||
) | ||
# TODO(woosuk): Unnecessary copy. Optimize. | ||
output[start:end].copy_(out) | ||
start += prompt_len | ||
|
||
# Using view got RuntimeError: view size is not compatible | ||
# with input tensor's size and stride (at least one | ||
# dimension spans across two contiguous subspaces). | ||
# Use reshape instead. | ||
return output.reshape(num_tokens, -1) | ||
|
||
|
||
def _naive_masked_attention( | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
scale: float, | ||
) -> torch.Tensor: | ||
seq_len, _, _ = query.shape | ||
attn_mask = torch.triu(torch.ones(seq_len, | ||
seq_len, | ||
dtype=query.dtype, | ||
device=query.device), | ||
diagonal=1) | ||
attn_mask = attn_mask * torch.finfo(query.dtype).min | ||
|
||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() | ||
attn_weights = attn_weights + attn_mask.float() | ||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) | ||
out = torch.einsum("hqk,khd->qhd", attn_weights, value) | ||
return out |
Oops, something went wrong.