From fb1d8623ae35d326ee049ea4a84c015fd3cc8604 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Mon, 25 Mar 2024 23:33:59 +0000 Subject: [PATCH 01/21] Add triton version of FA --- Dockerfile.rocm | 14 + vllm/attention/backends/flash_attn.py | 41 +- vllm/attention/backends/flash_attn_triton.py | 176 ++++++ vllm/attention/ops/flash_attention_triton.py | 557 +++++++++++++++++++ vllm/attention/selector.py | 42 +- 5 files changed, 806 insertions(+), 24 deletions(-) create mode 100644 vllm/attention/backends/flash_attn_triton.py create mode 100644 vllm/attention/ops/flash_attention_triton.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a45265d79a6ac..e7f52307a6aa2 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -26,6 +26,9 @@ ARG BUILD_FA="1" # whether to build cupy on rocm ARG BUILD_CUPY="1" +# whether to build triton on rocm +ARG BUILD_TRITON="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -95,6 +98,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \ && cd ..; \ fi +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCmSoftwarePlatform/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e50d52377b8e0..0de55bbd54d54 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,6 +14,7 @@ AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) +from vllm.utils import is_hip class FlashAttentionBackend(AttentionBackend): @@ -192,19 +193,33 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if is_hip(): + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + ) + else: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: # prefix-enabled attention output = PagedAttention.forward_prefix( diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/flash_attn_triton.py new file mode 100644 index 0000000000000..2b4761867446f --- /dev/null +++ b/vllm/attention/backends/flash_attn_triton.py @@ -0,0 +1,176 @@ +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. 
+""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.attention.ops.flash_attention_triton import triton_attention + + +class FlashAttentionTritonBackend(FlashAttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionTritonImpl + + +class FlashAttentionTritonImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + NOTE: This code is mostly duplicate from flash_attn backend + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return ( + x[:, :, None, :] + .expand(tokens, n_kv_heads, n_rep, head_dim) + .reshape(tokens, n_kv_heads * n_rep, head_dim) + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
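+        # hidden_size == heads * head_size, so the flat [num_tokens, hidden_size]
+        # tensors are viewed as [num_tokens, heads, head_size] for the paged
+        # cache write and the attention kernels below.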
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + + output, _ = triton_attention( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/flash_attention_triton.py new file mode 100644 index 0000000000000..cdc87ac7ceb76 --- /dev/null +++ b/vllm/attention/ops/flash_attention_triton.py @@ -0,0 +1,557 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
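+
+The varlen entry point (triton_attention) takes q, k and v packed as
+[total_tokens, nheads, head_size] together with cu_seqlens_q / cu_seqlens_k
+prefix sums, roughly mirroring the flash_attn_varlen_func interface used by
+the flash_attn backend.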
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype:tl.constexpr = torch.float16 + +@triton.jit +def cdiv_fn(x,y): + return (x + y - 1) // y + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + +@triton.jit +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr +): + # loop over k, v, and update accumulator + for start_n in range (block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
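+            # For example, with actual_seqlen_k = 100 and BLOCK_N = 64, the
+            # last k block spans key columns 64..127 but only 64..99 are real,
+            # so columns >= 100 are forced to -inf here.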
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + m_ij = tl.maximum(m_i, tl.max(qk,1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['hq', 'hk', 'IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, + hq, hk, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, + BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. 
+ if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + #tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + #l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + #tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. 
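+    # exp(x) == exp2(x * log2(e)) with log2(e) ~= 1.44269504089, so folding
+    # log2(e) into the softmax scale lets the inner loop use exp2 directly.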
+ qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. 
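+    # Rows with global index < seqlen_q - seqlen_k attend to no keys at all
+    # (their qk row is all -inf, so p becomes NaN); the causal epilogue below
+    # overwrites those rows of acc with zeros.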
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + #overflow_size = end_m_idx - seqlen_q + #if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + #else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0,1)) + +def check_args(q, k, v, o, max_seqlens): + pass + assert q.dim() == k.dim() and q.dim() == v.dim() + + if self.varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias == None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, causal=False, sm_scale=1.0, bias=None): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + #check_args(q, k, v, o, metadata.max_seq_len) + if True: #varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), 
q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META['BLOCK_M']), + nheads_q, + batch + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = (bias.stride(0), bias.stride(1), + bias.stride(2), bias.stride(3)) + else: + bias_strides = (0,0,0,0) + + attn_fwd[grid]( + q, k, v, bias, sm_scale, None, o, + *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + cu_seqlens_q, cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + +triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 90fce1a0349b2..907255a00af9a 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,8 @@ from functools import lru_cache import torch +import os + from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger @@ -11,11 +13,16 @@ @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: - if _can_use_flash_attn(dtype): + if _which_attn_to_use(dtype) == "FlashAttention": logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend + elif _which_attn_to_use(dtype) == "FlashAttentionTriton": + logger.info("Using FlashAttentionTriton backend.") + from vllm.attention.backends.flash_attn_triton import ( # noqa: F401 + FlashAttentionTritonBackend) + return FlashAttentionTritonBackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -23,24 +30,37 @@ def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: return XFormersBackend -def _can_use_flash_attn(dtype: torch.dtype) -> bool: - if is_hip(): - # AMD GPUs. - logger.info("Cannot use FlashAttention backend for AMD GPUs.") - return False - if torch.cuda.get_device_capability()[0] < 8: +def _which_attn_to_use(dtype: torch.dtype) -> str: + """Returns if and which flash attention to use. + + Returns: + int: 0 for Xformers, 1 for default implementation, 2 for triton implementation. 
+ """ + + use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "True").lower() in ("true", "1") + if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " "GPUs.") - return False + return "XFormers" + + if is_hip() and torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs. " + "Using xformers backend.") + return "XFormers" + if dtype not in (torch.float16, torch.bfloat16): logger.info("Cannot use FlashAttention backend for dtype other than " "torch.float16 or torch.bfloat16.") - return False + return "XFormers" try: import flash_attn # noqa: F401 except ImportError: logger.info("flash_attn is not found.") - return False - return True + if is_hip() and use_flash_attn_triton: + pass + else: + return "XFormers" + return "FlashAttentionTriton" if use_flash_attn_triton else "FlashAttention" From 948684e8bfb6ca780ae708d0376baf675c597eb9 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 16:07:42 +0000 Subject: [PATCH 02/21] Re-Enabled the param checks --- vllm/attention/ops/flash_attention_triton.py | 20 +++++++------------- vllm/attention/selector.py | 7 ++++--- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/flash_attention_triton.py index cdc87ac7ceb76..af3f3dab3abe3 100644 --- a/vllm/attention/ops/flash_attention_triton.py +++ b/vllm/attention/ops/flash_attention_triton.py @@ -440,22 +440,15 @@ def attn_fwd( # TODO: Do the boundary check optionally. tl.store(O_block_ptr, acc, boundary_check=(0,1)) -def check_args(q, k, v, o, max_seqlens): - pass +def check_args(q, k, v, o, varlen=True, max_seqlens=None, cu_seqlens_q=None, cu_seqlens_k=None): assert q.dim() == k.dim() and q.dim() == v.dim() - - if self.varlen: + if varlen: assert q.dim() == 3 total_q, nheads_q, head_size = q.shape total_k, nheads_k, _ = k.shape - assert self.cu_seqlens_q is not None - assert self.cu_seqlens_k is not None - assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) - # TODO: Remove once bias is supported with varlen - assert self.bias == None - # TODO:Remove once dropout is supported with varlen - assert self.dropout_p == 0.0 - assert not self.return_encoded_softmax + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) else: assert q.dim() == 4 batch, nheads_q, seqlen_q, head_size = q.shape @@ -475,7 +468,8 @@ class _attention(torch.autograd.Function): def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, causal=False, sm_scale=1.0, bias=None): if o is None: o = torch.empty_like(q, dtype=v.dtype) - #check_args(q, k, v, o, metadata.max_seq_len) + + check_args(q, k, v, o, varlen=True, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k) if True: #varlen total_q, nheads_q, head_size = q.shape total_k, nheads_k, _ = k.shape diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 907255a00af9a..03a2c142c4fae 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -31,12 +31,13 @@ def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: def _which_attn_to_use(dtype: torch.dtype) -> str: - """Returns if and which flash attention to use. + """Returns which flash attention backend to use. Returns: - int: 0 for Xformers, 1 for default implementation, 2 for triton implementation. 
+ str: XFormers, FlashAttention, or FlashAttentionTriton """ - + + # NOTE: Defaulting to triton FA for AMD cards. use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "True").lower() in ("true", "1") if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. From 906905f32ed9048d23b5c7e0ed400fe5f3139212 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 18:55:17 +0000 Subject: [PATCH 03/21] Run yapf and ruffwq --- vllm/attention/backends/flash_attn_triton.py | 45 +- vllm/attention/ops/flash_attention_triton.py | 553 ++++++++++++++----- vllm/attention/selector.py | 8 +- 3 files changed, 433 insertions(+), 173 deletions(-) diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/flash_attn_triton.py index 2b4761867446f..bb349db9cb1a9 100644 --- a/vllm/attention/backends/flash_attn_triton.py +++ b/vllm/attention/backends/flash_attn_triton.py @@ -4,35 +4,35 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or flashinfer for all the attention operations. """ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Type import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) +from vllm.attention.backends.abstract import ( + AttentionImpl, ) +from vllm.attention.backends.flash_attn import ( + FlashAttentionBackend, + FlashAttentionMetadata, +) +from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.flash_attention_triton import triton_attention class FlashAttentionTritonBackend(FlashAttentionBackend): @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: + def get_impl_cls() -> Type["FlashAttentionTritonImpl"]: return FlashAttentionTritonImpl class FlashAttentionTritonImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| + |<--------------- num_prompt_tokens -------------->| |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -75,11 +75,10 @@ def __init__( def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = x.shape - return ( - x[:, :, None, :] - .expand(tokens, n_kv_heads, n_rep, head_dim) - .reshape(tokens, n_kv_heads * n_rep, head_dim) - ) + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) def forward( self, @@ -113,10 +112,14 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. 
- PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype) + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + ) if attn_metadata.is_prompt: # Prompt run. diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/flash_attention_triton.py index af3f3dab3abe3..3cb92637ea8f2 100644 --- a/vllm/attention/ops/flash_attention_triton.py +++ b/vllm/attention/ops/flash_attention_triton.py @@ -3,7 +3,8 @@ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) Credits: OpenAI kernel team, AMD ML Frameworks Triton team Features supported: @@ -23,57 +24,71 @@ import triton import triton.language as tl -torch_dtype:tl.constexpr = torch.float16 +torch_dtype: tl.constexpr = torch.float16 + @triton.jit -def cdiv_fn(x,y): +def cdiv_fn(x, y): return (x + y - 1) // y + @triton.jit def max_fn(x, y): return tl.math.max(x, y) + @triton.jit def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] + @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) + @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) rng_keep = rng_output > dropout_p return rng_keep + @triton.jit def load_fn(block_ptr, first, second, pad): if first and second: - tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) elif first: - tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) elif second: - tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) else: tensor = tl.load(block_ptr) return tensor + @triton.jit def _attn_fwd_inner( - acc, l_i, m_i, q, - K_block_ptr, V_block_ptr, + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, start_m, actual_seqlen_k, dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - block_min, block_max, + block_min, + block_max, offs_n_causal, masked_blocks, n_extra_tokens, @@ -88,29 +103,41 @@ def _attn_fwd_inner( MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr + PADDED_HEAD: tl.constexpr, ): # loop over k, v, and update accumulator - for start_n in range (block_min, block_max, BLOCK_N): + for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. 
- k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) if PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: + if MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None,:] - mask = size_n < boundary_m[:,None] + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -119,30 +146,52 @@ def _attn_fwd_inner( # -- compute qk ---- qk += tl.dot(q, k) if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an additional + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. 
- qk += (bias * 1.44269504089) - m_ij = tl.maximum(m_i, tl.max(qk,1)) + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i @@ -153,43 +202,148 @@ def _attn_fwd_inner( if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) return acc, l_i, m_i + @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
- # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - ], - key=['hq', 'hk', 'IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], ) @triton.jit def attn_fwd( - Q, K, V, bias, sm_scale, L, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - stride_bz, stride_bh, stride_bm, stride_bn, - cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, - hq, hk, - ACTUAL_BLOCK_DMODEL:tl.constexpr, - MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + hq, + hk, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) @@ -220,81 +374,85 @@ def attn_fwd( # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): + if IS_CAUSAL: # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
# If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, - BLOCK_N - ) + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) # We still need to write 0s to the result - #tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) - #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # We store inf to LSE, not -inf because in the bwd pass, we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - #l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - #tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here too? + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? return is_mqa = hq != hk off_h_k = off_h_q % hk if is_mqa else off_h_q - need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: - need_padding = True n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: - need_padding = True n_extra_tokens = seqlen_k % BLOCK_N - padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL # Compute pointers for all the tensors used in this kernel. 
- q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) K_block_ptr = tl.make_block_ptr( base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) + order=(0, 1), ) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) V_block_ptr = tl.make_block_ptr( base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) if BIAS_TYPE != 0: bias_ptr = tl.make_block_ptr( @@ -308,21 +466,23 @@ def attn_fwd( else: bias_ptr = None if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + batch_philox_offset = philox_offset_base + off_z * hq \ + + off_h_q * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. In - # this case, we return an invalid pointer so indicate the mask is not valid. + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0) - ) + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) else: encoded_softmax_block_ptr = 0 # initialize pointer to m and l @@ -346,52 +506,95 @@ def attn_fwd( else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual + # Compute for full blocks. Here we set causal to false regardless of its # value because there is no masking. Similarly we do not need padding. 
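+    # Splitting the loop keeps per-element mask checks off the hot path: only
+    # the last masked_blocks k/v blocks (at most BLOCK_M // BLOCK_N + 1) pay
+    # for causal / padding masking.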
if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, bias_ptr, + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, ) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, - IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, ) # epilogue acc = acc / l_i[:, None] @@ -405,42 +608,57 @@ def attn_fwd( start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: + if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. 
For others, overflow_size will be -ve - #overflow_size = end_m_idx - seqlen_q - #if overflow_size > 0: + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) # # This is a > check because mask being 0 blocks the store. # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - #else: + # else: # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) - # Need boundary check on this to make sure the padding from the + # Need boundary check on this to make sure the padding from the # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0,1)) - -def check_args(q, k, v, o, varlen=True, max_seqlens=None, cu_seqlens_q=None, cu_seqlens_k=None): + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): assert q.dim() == k.dim() and q.dim() == v.dim() if varlen: assert q.dim() == 3 @@ -463,14 +681,37 @@ def check_args(q, k, v, o, varlen=True, max_seqlens=None, cu_seqlens_q=None, cu_ assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 + class _attention(torch.autograd.Function): + @staticmethod - def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, causal=False, sm_scale=1.0, bias=None): + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): if o is None: o = torch.empty_like(q, dtype=v.dtype) - check_args(q, k, v, o, varlen=True, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k) - if True: #varlen + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen total_q, nheads_q, head_size = q.shape total_k, nheads_k, _ = k.shape batch = len(cu_seqlens_q) - 1 @@ -498,11 +739,10 @@ def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seql else: padded_d_model = head_size - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META['BLOCK_M']), + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, - batch + batch, ) encoded_softmax = None @@ -512,20 +752,36 @@ def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seql philox_offset = 0x1D4B42 if bias is not None: - bias_strides = (bias.stride(0), bias.stride(1), - bias.stride(2), bias.stride(3)) + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) else: - bias_strides = (0,0,0,0) + bias_strides = (0, 0, 0, 0) attn_fwd[grid]( - q, k, v, bias, sm_scale, None, o, - *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, 
- cu_seqlens_q, cu_seqlens_k, + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - hq=nheads_q, hk=nheads_k, + hq=nheads_q, + hk=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, @@ -534,7 +790,7 @@ def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seql BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if bias is None else 1, ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False + RETURN_ENCODED_SOFTMAX=False, ) ctx.grid = grid @@ -548,4 +804,5 @@ def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seql ctx.return_encoded_softmax = False return o, encoded_softmax + triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 03a2c142c4fae..19a33459cecfe 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -3,7 +3,6 @@ import torch import os - from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_hip @@ -36,9 +35,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: Returns: str: XFormers, FlashAttention, or FlashAttentionTriton """ - + # NOTE: Defaulting to triton FA for AMD cards. - use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "True").lower() in ("true", "1") + use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', + "True").lower() in ("true", "1") if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " @@ -63,5 +63,5 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: if is_hip() and use_flash_attn_triton: pass else: - return "XFormers" + return "XFormers" return "FlashAttentionTriton" if use_flash_attn_triton else "FlashAttention" From 80c7cabb2bd16834bb5197200fd61e14913facaf Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 19:05:09 +0000 Subject: [PATCH 04/21] Ununsed variable --- vllm/attention/ops/flash_attention_triton.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/flash_attention_triton.py index 3cb92637ea8f2..b86e845020b07 100644 --- a/vllm/attention/ops/flash_attention_triton.py +++ b/vllm/attention/ops/flash_attention_triton.py @@ -466,8 +466,9 @@ def attn_fwd( else: bias_ptr = None if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + off_z * hq \ - + off_h_q * seqlen_q * seqlen_k + batch_philox_offset = philox_offset_base \ + + (off_z * hq + off_h_q) \ + * seqlen_q * seqlen_k else: batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. 
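Note on the hunk above: it restores the grouping that the earlier line wrap lost. Without the parentheses, the `off_z * hq` term is not scaled by `seqlen_q * seqlen_k`, so the dropout offset chosen for one batch can land inside another batch's range. With the fix, every (batch, query-head) pair owns a disjoint block of `seqlen_q * seqlen_k` Philox values. The snippet below only restates that layout for illustration; the helper name and the example sizes are my own, not code from the patch (only the `0x1D4B42` base appears in the series).

```python
# Illustration of the per-(batch, query-head) Philox offset layout after the
# fix above. Argument names mirror the kernel parameters; the sizes used here
# are made-up example values, not anything vLLM actually passes.

def batch_philox_offset(base: int, off_z: int, off_h_q: int,
                        hq: int, seqlen_q: int, seqlen_k: int) -> int:
    # Each (batch, query-head) pair owns a disjoint block of
    # seqlen_q * seqlen_k dropout random numbers.
    return base + (off_z * hq + off_h_q) * seqlen_q * seqlen_k

base, hq, seqlen_q, seqlen_k = 0x1D4B42, 4, 128, 128
offsets = {(z, h): batch_philox_offset(base, z, h, hq, seqlen_q, seqlen_k)
           for z in range(2) for h in range(hq)}

# Consecutive (batch, head) pairs sit exactly seqlen_q * seqlen_k apart,
# so no two attention matrices reuse the same dropout stream.
assert len(set(offsets.values())) == len(offsets)
```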
From 734fce7285cff6e665b561096eda9ac366029b82 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 19:25:16 +0000 Subject: [PATCH 05/21] Re-ran formater --- vllm/attention/backends/flash_attn_triton.py | 11 ++++------- vllm/attention/selector.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/flash_attn_triton.py index bb349db9cb1a9..3a136d0a2476b 100644 --- a/vllm/attention/backends/flash_attn_triton.py +++ b/vllm/attention/backends/flash_attn_triton.py @@ -8,14 +8,11 @@ import torch -from vllm.attention.backends.abstract import ( - AttentionImpl, ) -from vllm.attention.backends.flash_attn import ( - FlashAttentionBackend, - FlashAttentionMetadata, -) -from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.backends.abstract import AttentionImpl +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionMetadata) from vllm.attention.ops.flash_attention_triton import triton_attention +from vllm.attention.ops.paged_attn import PagedAttention class FlashAttentionTritonBackend(FlashAttentionBackend): diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 19a33459cecfe..8f25adfbf7f1f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,7 +1,7 @@ +import os from functools import lru_cache import torch -import os from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger From c91e5c35746f57336b68a089eb1e77ad66f5f138 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 21:04:56 +0000 Subject: [PATCH 06/21] Make variable code more clear and simplify attn selector --- vllm/attention/selector.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 8f25adfbf7f1f..bd4d2e9c98506 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -36,9 +36,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: str: XFormers, FlashAttention, or FlashAttentionTriton """ - # NOTE: Defaulting to triton FA for AMD cards. - use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', - "True").lower() in ("true", "1") + # NOTE: Allow for switching between Triton and FA + # Defaulting to triton FA for AMD cards. + use_flash_attn_triton = (os.environ.get("VLLM_USE_FLASH_ATTN_TRITON", + "True").lower() + in ("true", "1")) and not is_hip() + if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " @@ -56,12 +59,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: "torch.float16 or torch.bfloat16.") return "XFormers" - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info("flash_attn is not found.") - if is_hip() and use_flash_attn_triton: - pass - else: + if not use_flash_attn_triton: + # Only test for flash_attn if we are using it. 
+ try: + import flash_attn # noqa: F401 + except ImportError: return "XFormers" + return "FlashAttentionTriton" if use_flash_attn_triton else "FlashAttention" From 534e1f941395bb2bc3c4f8004ff739b0ce5fdc55 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 26 Mar 2024 22:58:11 +0000 Subject: [PATCH 07/21] Logic mistake --- vllm/attention/selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index bd4d2e9c98506..726c16e362137 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -40,7 +40,7 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: # Defaulting to triton FA for AMD cards. use_flash_attn_triton = (os.environ.get("VLLM_USE_FLASH_ATTN_TRITON", "True").lower() - in ("true", "1")) and not is_hip() + in ("true", "1")) and is_hip() if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. From 11437bc89f4c5d1b3f37e7296724f2722f21e397 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Thu, 28 Mar 2024 21:19:32 +0000 Subject: [PATCH 08/21] Review comments Rename backend --- Dockerfile.rocm | 2 +- requirements-rocm.txt | 2 +- vllm/attention/backends/flash_attn_triton.py | 8 ++++---- vllm/attention/selector.py | 18 +++++++++--------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 37c24ab563f3f..e9ba74181c6ea 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -104,6 +104,6 @@ RUN cd /app \ && cd .. RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all] +RUN python3 -m pip install --no-cache-dir ray[all]]==2.9.3 CMD ["/bin/bash"] diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 6acf70695cef8..4883c44280a2e 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -3,7 +3,7 @@ ninja # For faster builds. typing-extensions>=4.8.0 starlette psutil -ray >= 2.9 +ray == 2.9.3 sentencepiece # Required for LLaMA tokenizer. 
numpy tokenizers>=0.15.0 diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/flash_attn_triton.py index 3a136d0a2476b..a9dc57e74d4a4 100644 --- a/vllm/attention/backends/flash_attn_triton.py +++ b/vllm/attention/backends/flash_attn_triton.py @@ -15,14 +15,14 @@ from vllm.attention.ops.paged_attn import PagedAttention -class FlashAttentionTritonBackend(FlashAttentionBackend): +class TritonFlashAttentionBackend(FlashAttentionBackend): @staticmethod - def get_impl_cls() -> Type["FlashAttentionTritonImpl"]: - return FlashAttentionTritonImpl + def get_impl_cls() -> Type["TritonFlashAttentionImpl"]: + return TritonFlashAttentionImpl -class FlashAttentionTritonImpl(AttentionImpl): +class TritonFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prompt_tokens -------------->| diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 726c16e362137..bd40a2879d4e6 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -17,11 +17,11 @@ def get_attn_backend(dtype: torch.dtype) -> AttentionBackend: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif _which_attn_to_use(dtype) == "FlashAttentionTriton": - logger.info("Using FlashAttentionTriton backend.") - from vllm.attention.backends.flash_attn_triton import ( # noqa: F401 - FlashAttentionTritonBackend) - return FlashAttentionTritonBackend + elif _which_attn_to_use(dtype) == "TritonFlashAttention": + logger.info("Using TritonFlashAttention backend.") + from vllm.attention.backends.triton_flash_attn import ( # noqa: F401 + TritonFlashAttentionBackend) + return TritonFlashAttentionBackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -33,12 +33,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: """Returns which flash attention backend to use. Returns: - str: XFormers, FlashAttention, or FlashAttentionTriton + str: XFormers, FlashAttention, or TritonFlashAttention """ # NOTE: Allow for switching between Triton and FA # Defaulting to triton FA for AMD cards. - use_flash_attn_triton = (os.environ.get("VLLM_USE_FLASH_ATTN_TRITON", + use_triton_flash_attn = (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) and is_hip() @@ -59,11 +59,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: "torch.float16 or torch.bfloat16.") return "XFormers" - if not use_flash_attn_triton: + if not use_triton_flash_attn: # Only test for flash_attn if we are using it. 
try: import flash_attn # noqa: F401 except ImportError: return "XFormers" - return "FlashAttentionTriton" if use_flash_attn_triton else "FlashAttention" + return "TritonFlashAttention" if use_triton_flash_attn else "FlashAttention" From 320d7c4fb230c3c9916bbf58716ab287d7bdba68 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Thu, 28 Mar 2024 21:20:27 +0000 Subject: [PATCH 09/21] File rename --- .../backends/{flash_attn_triton.py => triton_flash_attn.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/attention/backends/{flash_attn_triton.py => triton_flash_attn.py} (100%) diff --git a/vllm/attention/backends/flash_attn_triton.py b/vllm/attention/backends/triton_flash_attn.py similarity index 100% rename from vllm/attention/backends/flash_attn_triton.py rename to vllm/attention/backends/triton_flash_attn.py From b8919d185fb2afb4d53b7090e9ba244243613f52 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Thu, 28 Mar 2024 21:24:13 +0000 Subject: [PATCH 10/21] File rename --- vllm/attention/backends/triton_flash_attn.py | 2 +- .../{flash_attention_triton.py => triton_flash_attention.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename vllm/attention/ops/{flash_attention_triton.py => triton_flash_attention.py} (100%) diff --git a/vllm/attention/backends/triton_flash_attn.py b/vllm/attention/backends/triton_flash_attn.py index a9dc57e74d4a4..03b7195f534e9 100644 --- a/vllm/attention/backends/triton_flash_attn.py +++ b/vllm/attention/backends/triton_flash_attn.py @@ -11,7 +11,7 @@ from vllm.attention.backends.abstract import AttentionImpl from vllm.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) -from vllm.attention.ops.flash_attention_triton import triton_attention +from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.paged_attn import PagedAttention diff --git a/vllm/attention/ops/flash_attention_triton.py b/vllm/attention/ops/triton_flash_attention.py similarity index 100% rename from vllm/attention/ops/flash_attention_triton.py rename to vllm/attention/ops/triton_flash_attention.py From ebba1982aceb4ea424c62044f4cedef186a4a8c6 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Thu, 28 Mar 2024 21:59:54 +0000 Subject: [PATCH 11/21] Run yapf and ruff --- vllm/attention/backends/triton_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/triton_flash_attn.py b/vllm/attention/backends/triton_flash_attn.py index 03b7195f534e9..19ff0449eaf2a 100644 --- a/vllm/attention/backends/triton_flash_attn.py +++ b/vllm/attention/backends/triton_flash_attn.py @@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import AttentionImpl from vllm.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) -from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.ops.triton_flash_attention import triton_attention class TritonFlashAttentionBackend(FlashAttentionBackend): From f0d1eeb2740124e7bb8c01c8e6b8915116d9bf4d Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 29 Mar 2024 16:50:31 +0000 Subject: [PATCH 12/21] Extra bracket in dockerfile --- Dockerfile.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index e9ba74181c6ea..10b8bf1e7fabd 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -104,6 +104,6 @@ RUN cd /app \ && cd .. 
RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all]]==2.9.3 +RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 CMD ["/bin/bash"] From 215d15f14a765b57fd1dc649a133c8e93831a185 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 2 Apr 2024 15:14:16 +0000 Subject: [PATCH 13/21] File rename to rocm fa --- .../backends/{triton_flash_attn.py => rocm_flash_atten.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/attention/backends/{triton_flash_attn.py => rocm_flash_atten.py} (100%) diff --git a/vllm/attention/backends/triton_flash_attn.py b/vllm/attention/backends/rocm_flash_atten.py similarity index 100% rename from vllm/attention/backends/triton_flash_attn.py rename to vllm/attention/backends/rocm_flash_atten.py From c6cfdffcb768ceb0cc721a1551d864e48be84cc2 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 2 Apr 2024 16:01:16 +0000 Subject: [PATCH 14/21] Make a ROCM FA backend --- vllm/attention/backends/rocm_flash_atten.py | 176 ------------ vllm/attention/backends/rocm_flash_attn.py | 286 ++++++++++++++++++++ vllm/attention/selector.py | 25 +- 3 files changed, 295 insertions(+), 192 deletions(-) delete mode 100644 vllm/attention/backends/rocm_flash_atten.py create mode 100644 vllm/attention/backends/rocm_flash_attn.py diff --git a/vllm/attention/backends/rocm_flash_atten.py b/vllm/attention/backends/rocm_flash_atten.py deleted file mode 100644 index 19ff0449eaf2a..0000000000000 --- a/vllm/attention/backends/rocm_flash_atten.py +++ /dev/null @@ -1,176 +0,0 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" -from typing import List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionMetadata) -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_flash_attention import triton_attention - - -class TritonFlashAttentionBackend(FlashAttentionBackend): - - @staticmethod - def get_impl_cls() -> Type["TritonFlashAttentionImpl"]: - return TritonFlashAttentionImpl - - -class TritonFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. 
- - NOTE: This code is mostly duplicate from flash_attn backend - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") - - def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - tokens, n_kv_heads, head_dim = x.shape - return (x[:, :, - None, :].expand(tokens, n_kv_heads, n_rep, - head_dim).reshape(tokens, n_kv_heads * n_rep, - head_dim)) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - ) - - if attn_metadata.is_prompt: - # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: - # triton attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - - if self.num_kv_heads != self.num_heads: - # Interleave for MQA workaround. 
- key = self.repeat_kv(key, self.num_queries_per_kv) - value = self.repeat_kv(value, self.num_queries_per_kv) - - output, _ = triton_attention( - query, - key, - value, - None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, - True, - self.scale, - ) - else: - # prefix-enabled attention - output = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, - self.alibi_slopes, - ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, - key_cache, - value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, - attn_metadata.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - - # Reshape the output tensor. - return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000000000..f90f6d328a542 --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,286 @@ +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. 
+ prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. 
" + f"Supported head sizes are: {suppored_head_sizes}.") + + # NOTE: Allow for switching between Triton and CK + # Defaulting to triton + self.use_triton_flash_attn = (os.environ.get( + "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + if self.use_triton_flash_attn: + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.fa_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + else: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.fa_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + ) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + if self.use_triton_flash_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + + output, _ = self.fa_func( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) + else: + output = self.fa_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + ) + + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. 
+ output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c831cafffa194..80ecc6e754e6d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,3 @@ -import os from functools import lru_cache from typing import Type @@ -18,11 +17,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif _which_attn_to_use(dtype) == "TritonFlashAttention": - logger.info("Using TritonFlashAttention backend.") - from vllm.attention.backends.triton_flash_attn import ( # noqa: F401 - TritonFlashAttentionBackend) - return TritonFlashAttentionBackend + elif _which_attn_to_use(dtype) == "ROCmFlashAttention": + logger.info("Using ROCmFlashAttention backend.") + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -34,15 +33,9 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: """Returns which flash attention backend to use. Returns: - str: XFormers, FlashAttention, or TritonFlashAttention + str: XFormers, FlashAttention, or ROCmFlashAttention """ - # NOTE: Allow for switching between Triton and FA - # Defaulting to triton FA for AMD cards. - use_triton_flash_attn = (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", - "True").lower() - in ("true", "1")) and is_hip() - if not is_hip() and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " @@ -60,8 +53,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: "torch.float16 or torch.bfloat16.") return "XFormers" - if not use_triton_flash_attn: - # Only test for flash_attn if we are using it. + if not is_hip(): + # ROCm backend has its own check for this. try: import flash_attn # noqa: F401 except ImportError: @@ -70,4 +63,4 @@ def _which_attn_to_use(dtype: torch.dtype) -> str: "Please install it for better performance.") return "XFormers" - return "TritonFlashAttention" if use_triton_flash_attn else "FlashAttention" + return "ROCmFlashAttention" if is_hip() else "FlashAttention" From e691488626b860c1c500da4db9bb8c692e269d8c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 17:46:49 +0000 Subject: [PATCH 15/21] Add kv_scale to ROCM flash attn --- vllm/attention/backends/rocm_flash_attn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f90f6d328a542..b18dfa307f12b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,9 +1,4 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. 
-""" +"""Attention layer ROCm GPUs.""" import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -181,6 +176,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -213,6 +209,7 @@ def forward( value_cache, attn_metadata.slot_mapping, attn_metadata.kv_cache_dtype, + kv_scale, ) if attn_metadata.is_prompt: @@ -280,6 +277,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + kv_scale, ) # Reshape the output tensor. From 26457faf70728da8faa4d7b4f6b8352cfe4503e6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 17:47:09 +0000 Subject: [PATCH 16/21] Remove is_hip from FlashAttentionBackend --- vllm/attention/backends/flash_attn.py | 41 +++++++++------------------ 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index a34e99e0a1575..4e0d9d1418b32 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,7 +14,6 @@ AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) -from vllm.utils import is_hip class FlashAttentionBackend(AttentionBackend): @@ -195,33 +194,19 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if is_hip(): - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - ) - else: - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to From 03679f4bc12a8c98f64a04c9351ff190f854ddbe Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 18:06:41 +0000 Subject: [PATCH 17/21] Move naive attention to rocm flash attn --- vllm/attention/backends/rocm_flash_attn.py | 115 +++++++++++++++++---- vllm/attention/backends/xformers.py | 74 +------------ 2 files changed, 96 insertions(+), 93 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b18dfa307f12b..403200d03bcd1 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,4 +1,5 @@ """Attention layer ROCm GPUs.""" +import importlib import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -147,18 +148,24 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. 
" f"Supported head sizes are: {suppored_head_sizes}.") - # NOTE: Allow for switching between Triton and CK - # Defaulting to triton + self.use_naive_attn = _check_use_naive_attention() + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = (os.environ.get( "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) - if self.use_triton_flash_attn: + if self.use_naive_attn: + # AMD Radeon 7900 series (gfx1100) currently does not support xFormers + # nor FlashAttention. As a temporary workaround, we use naive PyTorch + # implementation of attention. + self.attn_fuc = _naive_attention() + logger.debug("Using naive attention in ROCmBackend") + elif self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) - self.fa_func = triton_attention + self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: from flash_attn import flash_attn_varlen_func # noqa: F401 - self.fa_func = flash_attn_varlen_func + self.attn_func = flash_attn_varlen_func logger.debug("Using CK FA in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -218,26 +225,34 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_triton_flash_attn: + if self.use_naive_attn or self.use_triton_flash_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - - output, _ = self.fa_func( - query, - key, - value, - None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, - True, - self.scale, - ) + if self.use_naive_attn: + output = self.attn_fuc( + query, + key, + value, + attn_metadata.prompt_lens, + self.scale, + ) + else: + output, _ = self.attn_func( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) else: - output = self.fa_func( + output = self.attn_func( q=query, k=key, v=value, @@ -282,3 +297,63 @@ def forward( # Reshape the output tensor. return output.view(num_tokens, hidden_size) + + +def _check_use_naive_attention() -> bool: + # For ROCm, check whether flash attention is installed or not. + use_naive_attention = importlib.util.find_spec("flash_attn") is None + if use_naive_attention: + logger.warning("flash_attn is not installed. Using naive attention. " + "This will take significantly more GPU memory.") + return True + return False + + +def _naive_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + prompt_lens: List[int], + scale: float, +) -> torch.Tensor: + num_tokens = query.shape[0] + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(prompt_lens): + end = start + prompt_len + out = _naive_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + + # Using view got RuntimeError: view size is not compatible + # with input tensor's size and stride (at least one + # dimension spans across two contiguous subspaces). + # Use reshape instead. 
+ return output.reshape(num_tokens, -1) + + +def _naive_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +) -> torch.Tensor: + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d349c3ef19ea7..5c36061e20bf9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,5 +1,4 @@ """Attention layer with xFormers and PagedAttention.""" -import importlib from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -166,11 +165,6 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - # AMD Radeon 7900 series (gfx1100) currently does not support xFormers - # nor FlashAttention. As a temporary workaround, we use naive PyTorch - # implementation of attention. - self.use_naive_attention = _check_use_naive_attention() - def forward( self, query: torch.Tensor, @@ -233,30 +227,6 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - if self.use_naive_attention: - output = torch.empty_like(query) - start = 0 - for _, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len - out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out) - start += prompt_len - - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformers_forward( query, key, value, attn_metadata) else: @@ -329,8 +299,6 @@ def _run_memory_efficient_xformers_forward( self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.prompt_lens) - op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( - is_hip()) else None # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -344,8 +312,7 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) return out.view_as(query) @@ -405,42 +372,3 @@ def _make_alibi_bias( attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) return attn_biases - - -def _check_use_naive_attention() -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - use_naive_attention = importlib.util.find_spec("flash_attn") is None - if use_naive_attention: - logger.warning("flash_attn is not installed. Using naive attention. 
" - "This will take significantly more GPU memory.") - return True - return False - - -def _naive_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - scale: float, -) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out From 3d0db08a68b1d6ce0df2a63867177667baf3fa4d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 18:06:53 +0000 Subject: [PATCH 18/21] Refactor selector --- vllm/attention/selector.py | 75 ++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index e8f64642afb67..4c699aed48d49 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,3 +1,4 @@ +import enum from functools import lru_cache from typing import Type @@ -10,64 +11,68 @@ logger = init_logger(__name__) +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + + @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - if _which_attn_to_use(dtype) == "FlashAttention": + backend = _which_attn_to_use(dtype) + if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif _which_attn_to_use(dtype) == "ROCmFlashAttention": + elif backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.attention.backends.xformers import ( # noqa: F401 + XFormersBackend) + return XFormersBackend + elif backend == _Backend.ROCM_FLASH: logger.info("Using ROCmFlashAttention backend.") from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 ROCmFlashAttentionBackend) return ROCmFlashAttentionBackend - elif _which_attn_to_use(dtype) == "CPUFlashAttention": + elif backend == _Backend.TORCH_SDPA: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend else: - logger.info("Using XFormers backend.") - from vllm.attention.backends.xformers import ( # noqa: F401 - XFormersBackend) - return XFormersBackend - + raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> str: - """Returns which flash attention backend to use. - Returns: - str: XFormers, FlashAttention, - CPUFlashAttention, or ROCmFlashAttention - """ +def _which_attn_to_use(dtype: torch.dtype) -> _Backend: + """Returns which flash attention backend to use.""" if is_cpu(): - return "CPUFlashAttention" + return _Backend.TORCH_SDPA - if not is_hip() and torch.cuda.get_device_capability()[0] < 8: + if is_hip(): + # AMD GPUs. + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs.") + return _Backend.ROCM_FLASH + + # NVIDIA GPUs. 
+ if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " "GPUs.") - return "XFormers" - - if is_hip() and torch.cuda.get_device_capability()[0] != 9: - # not Instinct series GPUs. - logger.info("flash_atten is not supported on NAVI GPUs. " - "Using xformers backend.") - return "XFormers" + return _Backend.XFORMERS if dtype not in (torch.float16, torch.bfloat16): logger.info("Cannot use FlashAttention backend for dtype other than " "torch.float16 or torch.bfloat16.") - return "XFormers" - - if not is_hip(): - # ROCm backend has its own check for this. - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info( - "Cannot use FlashAttention because the package is not found. " - "Please install it for better performance.") - return "XFormers" + return _Backend.XFORMERS - return "ROCmFlashAttention" if is_hip() else "FlashAttention" + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info( + "Cannot use FlashAttention backend because the flash_attn package " + "is not found. Please install it for better performance.") + return _Backend.XFORMERS + return _Backend.FLASH_ATTN From 0c9b6688a7bf3797697dc3298b1227ae7148eb83 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 18:11:57 +0000 Subject: [PATCH 19/21] yapf --- vllm/attention/backends/rocm_flash_attn.py | 6 +++--- vllm/attention/backends/xformers.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 403200d03bcd1..32b2092c4030d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -153,9 +153,9 @@ def __init__( self.use_triton_flash_attn = (os.environ.get( "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) if self.use_naive_attn: - # AMD Radeon 7900 series (gfx1100) currently does not support xFormers - # nor FlashAttention. As a temporary workaround, we use naive PyTorch - # implementation of attention. + # AMD Radeon 7900 series (gfx1100) currently does not support + # xFormers nor FlashAttention. As a temporary workaround, we use + # naive PyTorch implementation of attention. self.attn_fuc = _naive_attention() logger.debug("Using naive attention in ROCmBackend") elif self.use_triton_flash_attn: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5c36061e20bf9..05b68bba5e6eb 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -13,7 +13,6 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -330,8 +329,7 @@ def _run_memory_efficient_xformers_forward( value[None, start:end], attn_bias=attn_metadata.attn_bias[i], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. 
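For readers tracking the selector changes, the following is a condensed, dependency-free restatement of the decision order that `_which_attn_to_use` ends up with after the PATCH 18/21 refactor. The enum members mirror `_Backend`; the standalone `choose_backend` helper and its explicit arguments (platform facts passed in rather than probed via `is_cpu`, `is_hip`, and `torch.cuda`) are this sketch's own simplification, not code from the series.

```python
import enum

import torch


class Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()


def choose_backend(dtype: torch.dtype, is_cpu: bool, is_hip: bool,
                   gpu_major: int, has_flash_attn: bool) -> Backend:
    if is_cpu:
        return Backend.TORCH_SDPA   # CPU path uses torch SDPA
    if is_hip:
        # ROCm backend picks Triton, CK, or the naive path internally;
        # non-gfx9 (NAVI) GPUs only get a log message, not a different backend.
        return Backend.ROCM_FLASH
    if gpu_major < 8:
        return Backend.XFORMERS     # Volta / Turing: no FlashAttention
    if dtype not in (torch.float16, torch.bfloat16):
        return Backend.XFORMERS     # FlashAttention needs fp16 or bf16
    if not has_flash_attn:
        return Backend.XFORMERS     # flash_attn package not installed
    return Backend.FLASH_ATTN


# e.g. an Instinct-class (gfx9) ROCm system always lands on the ROCm backend:
assert choose_backend(torch.float16, False, True, 9, True) is Backend.ROCM_FLASH
```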
output[start:end].copy_(out.squeeze(0)) start += prompt_len From c9393387b8bc662b2fae7f935086553e57859d68 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 18:30:43 +0000 Subject: [PATCH 20/21] Fix use_naive --- vllm/attention/backends/rocm_flash_attn.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 32b2092c4030d..bc527d4d5b5b6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -148,7 +148,7 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - self.use_naive_attn = _check_use_naive_attention() + self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = (os.environ.get( "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) @@ -183,7 +183,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -299,16 +299,6 @@ def forward( return output.view(num_tokens, hidden_size) -def _check_use_naive_attention() -> bool: - # For ROCm, check whether flash attention is installed or not. - use_naive_attention = importlib.util.find_spec("flash_attn") is None - if use_naive_attention: - logger.warning("flash_attn is not installed. Using naive attention. " - "This will take significantly more GPU memory.") - return True - return False - - def _naive_attention( query: torch.Tensor, key: torch.Tensor, From 1238bc131336e4814c77eee98c4423453e7d7a63 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Apr 2024 18:32:04 +0000 Subject: [PATCH 21/21] yapf --- vllm/attention/backends/rocm_flash_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index bc527d4d5b5b6..6019d917b4494 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,5 +1,4 @@ """Attention layer ROCm GPUs.""" -import importlib import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type