[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm #3643
Merged. WoosukKwon merged 26 commits into vllm-project:main from ROCm:jpvillam/triton_upstream_integration on Apr 9, 2024.
Changes from 7 commits are shown in the diff below.

Commits (26):
fb1d862  Add triton version of FA
948684e  Re-Enabled the param checks
906905f  Run yapf and ruffwq
80c7cab  Ununsed variable
734fce7  Re-ran formater
c91e5c3  Make variable code more clear and simplify attn selector
534e1f9  Logic mistake
5c4a993  Merge remote-tracking branch 'upstream/main' into jpvillam/triton_ups…
11437bc  Review comments (gshtras)
320d7c4  File rename
b8919d1  File rename
ebba198  Run yapf and ruff
f0d1eeb  Extra bracket in dockerfile
540c20a  Merge remote-tracking branch 'upstream/main' into jpvillam/triton_ups…
80c936c  Merge remote-tracking branch 'origin/jpvillam/triton_upstream_integra…
215d15f  File rename to rocm fa
c6cfdff  Make a ROCM FA backend
f19a592  Merge remote-tracking branch 'upstream/main' into jpvillam/triton_ups…
a115f65  Merge branch 'main' into jpvillam/triton_upstream_integration
e691488  Add kv_scale to ROCM flash attn (WoosukKwon)
26457fa  Remove is_hip from FlashAttentionBackend (WoosukKwon)
03679f4  Move naive attention to rocm flash attn (WoosukKwon)
3d0db08  Refactor selector (WoosukKwon)
0c9b668  yapf (WoosukKwon)
c939338  Fix use_naive (WoosukKwon)
1238bc1  yapf (WoosukKwon)
@@ -0,0 +1,176 @@
"""Attention layer with Flash and PagedAttention.

NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from typing import List, Optional, Type

import torch

from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
                                                FlashAttentionMetadata)
from vllm.attention.ops.flash_attention_triton import triton_attention
from vllm.attention.ops.paged_attn import PagedAttention


class FlashAttentionTritonBackend(FlashAttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionTritonImpl"]:
        return FlashAttentionTritonImpl


class FlashAttentionTritonImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    NOTE: This code is mostly duplicated from the flash_attn backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep)."""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :, None, :]
                .expand(tokens, n_kv_heads, n_rep, head_dim)
                .reshape(tokens, n_kv_heads * n_rep, head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.kv_cache_dtype,
            )

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # Triton attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.num_kv_heads != self.num_heads:
                    # Interleave for MQA workaround.
                    key = self.repeat_kv(key, self.num_queries_per_kv)
                    value = self.repeat_kv(value, self.num_queries_per_kv)

                output, _ = triton_attention(
                    query,
                    key,
                    value,
                    None,
                    attn_metadata.seq_start_loc,
                    attn_metadata.seq_start_loc,
                    attn_metadata.max_prompt_len,
                    attn_metadata.max_prompt_len,
                    True,
                    self.scale,
                )
            else:
                # Prefix-enabled attention.
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.subquery_start_loc,
                    attn_metadata.prompt_lens_tensor,
                    attn_metadata.context_lens,
                    attn_metadata.max_subquery_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
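The packed prompt layout described in the class docstring and the head interleaving done by repeat_kv can be illustrated with plain PyTorch. The following is a minimal, standalone sketch; the prompt lengths and head counts are made up for illustration, and the seq_start_loc construction only mirrors what the attention metadata would normally provide.

import torch

# Hypothetical sizes, for illustration only.
prompt_lens = torch.tensor([3, 5, 2])     # three prompts packed back to back
num_tokens = int(prompt_lens.sum())       # 10 packed prompt tokens
num_heads, num_kv_heads, head_size = 8, 2, 64
n_rep = num_heads // num_kv_heads         # queries per KV head (GQA/MQA)

# seq_start_loc-style cumulative offsets into the packed token dimension:
# tensor([ 0,  3,  8, 10])
seq_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long),
     torch.cumsum(prompt_lens, dim=0)])

# Head interleaving as in repeat_kv: expand each KV head n_rep times so the
# key/value head count matches the query head count expected by the kernel.
key = torch.randn(num_tokens, num_kv_heads, head_size)
repeated = (key[:, :, None, :]
            .expand(num_tokens, num_kv_heads, n_rep, head_size)
            .reshape(num_tokens, num_kv_heads * n_rep, head_size))

# Same result as torch.repeat_interleave along the head dimension.
assert torch.equal(repeated, torch.repeat_interleave(key, n_rep, dim=1))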
Review comment: QQ: Actually in #3648, we are planning to use FlashAttention's recent APIs that support attention keys and values stored in paged KV cache. I believe the new APIs are incompatible with AMD GPUs.
Can we just use Triton FlashAttention at all times?
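For context on that question, the sketch below shows the kind of dispatch being discussed: use the Triton FlashAttention kernel on ROCm and the upstream flash-attn package elsewhere. This is a hypothetical standalone helper, not vLLM's actual attention selector; the backend names and the availability checks are assumptions made for illustration.

import importlib.util

import torch


def pick_attention_backend() -> str:
    """Illustrative backend choice only; not vLLM's real selector logic."""
    is_hip = torch.version.hip is not None  # True on ROCm builds of PyTorch
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None

    if is_hip:
        # On ROCm, assume the flash-attn paged-KV APIs are unavailable, so use
        # the Triton kernel for prompts and PagedAttention for decode, as this
        # PR does.
        return "rocm-triton-flash-attn"
    if has_flash_attn:
        return "flash-attn"
    return "xformers"


print(pick_attention_backend())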