[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643)
Co-authored-by: jpvillam <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
4 people authored Apr 9, 2024
1 parent e23a43a commit 6c0b045
Showing 5 changed files with 1,213 additions and 93 deletions.
14 changes: 14 additions & 0 deletions Dockerfile.rocm
@@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build triton on rocm
ARG BUILD_TRITON="1"
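# (can be overridden at build time, e.g. `docker build --build-arg BUILD_TRITON=0 ...`,
# to skip building Triton from source)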

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
348 changes: 348 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -0,0 +1,348 @@
"""Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class ROCmFlashAttentionBackend(AttentionBackend):

@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int

# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|

# WARNING(sang): context_len has a different definition depending on
# whether it is prefill or decoding. For prefill, it doesn't include the
# new tokens. For decoding, it includes the new token.
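# Example (prefill with a cached prefix): if context_len = 12 tokens are
# already in the KV cache and 4 new tokens are scheduled this step, then
# subquery_len = 4 and seqlen = 12 + 4 = 16.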

# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]

# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool


class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")

self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = (os.environ.get(
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
if self.use_naive_attn:
# AMD Radeon 7900 series (gfx1100) currently supports neither
# xFormers nor FlashAttention. As a temporary workaround, we use a
# naive PyTorch implementation of attention.
self.attn_func = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
elif self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
return (x[:, :,
None, :].expand(tokens, n_kv_heads, n_rep,
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale,
)

if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_naive_attn or self.use_triton_flash_attn:
if self.num_kv_heads != self.num_heads:
# Interleave KV heads to match the query heads (MQA/GQA workaround).
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
output = self.attn_func(
query,
key,
value,
attn_metadata.prompt_lens,
self.scale,
)
else:
output, _ = self.attn_func(
query,
key,
value,
None,
attn_metadata.seq_start_loc,
attn_metadata.seq_start_loc,
attn_metadata.max_prompt_len,
attn_metadata.max_prompt_len,
True,
self.scale,
)
else:
output = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
softmax_scale=self.scale,
causal=True,
)

else:
# prefix-enabled attention
output = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)


def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
prompt_lens: List[int],
scale: float,
) -> torch.Tensor:
num_tokens = query.shape[0]
output = torch.empty_like(query)
start = 0
for prompt_len in prompt_lens:
end = start + prompt_len
out = _naive_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len

# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(num_tokens, -1)


def _naive_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
seq_len, _, _ = query.shape
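# Build a causal mask: strictly upper-triangular entries (future positions)
# are set to the most negative finite value so softmax gives them ~zero weight.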
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min

attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
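
A minimal sanity-check sketch, assuming PyTorch >= 2.1 (for the scale= argument of scaled_dot_product_attention) and small float32 inputs, comparing _naive_masked_attention against PyTorch's built-in scaled dot-product attention:

import torch
import torch.nn.functional as F

from vllm.attention.backends.rocm_flash_attn import _naive_masked_attention

# Toy inputs in the (seq_len, num_heads, head_dim) layout the naive path uses.
seq_len, num_heads, head_dim = 8, 4, 64
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)
scale = head_dim ** -0.5

ref = _naive_masked_attention(q, k, v, scale)

# scaled_dot_product_attention expects a (num_heads, seq_len, head_dim) layout.
sdpa = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True, scale=scale,
).transpose(0, 1)

torch.testing.assert_close(ref, sdpa, rtol=1e-3, atol=1e-3)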