From 6c0b04515fee7b402a6febde1467581825bb2164 Mon Sep 17 00:00:00 2001 From: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:10:47 -0500 Subject: [PATCH] [ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643) Co-authored-by: jpvillam Co-authored-by: Gregory Shtrasberg Co-authored-by: Woosuk Kwon --- Dockerfile.rocm | 14 + vllm/attention/backends/rocm_flash_attn.py | 348 ++++++++ vllm/attention/backends/xformers.py | 78 +- vllm/attention/ops/triton_flash_attention.py | 809 +++++++++++++++++++ vllm/attention/selector.py | 57 +- 5 files changed, 1213 insertions(+), 93 deletions(-) create mode 100644 vllm/attention/backends/rocm_flash_attn.py create mode 100644 vllm/attention/ops/triton_flash_attention.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 65a367994f960..10b8bf1e7fabd 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH" # In that case, we need to use the python reference attention implementation in vllm ARG BUILD_FA="1" +# whether to build triton on rocm +ARG BUILD_TRITON="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCm/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000000000..6019d917b4494 --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,348 @@ +"""Attention layer ROCm GPUs.""" +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = (os.environ.get( + "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + if self.use_naive_attn: + # AMD Radeon 7900 series (gfx1100) currently does not support + # xFormers nor FlashAttention. As a temporary workaround, we use + # naive PyTorch implementation of attention. + self.attn_fuc = _naive_attention() + logger.debug("Using naive attention in ROCmBackend") + elif self.use_triton_flash_attn: + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + else: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + kv_scale, + ) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + if self.use_naive_attn or self.use_triton_flash_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.use_naive_attn: + output = self.attn_fuc( + query, + key, + value, + attn_metadata.prompt_lens, + self.scale, + ) + else: + output, _ = self.attn_func( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) + else: + output = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + ) + + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + + +def _naive_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + prompt_lens: List[int], + scale: float, +) -> torch.Tensor: + num_tokens = query.shape[0] + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(prompt_lens): + end = start + prompt_len + out = _naive_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + + # Using view got RuntimeError: view size is not compatible + # with input tensor's size and stride (at least one + # dimension spans across two contiguous subspaces). + # Use reshape instead. + return output.reshape(num_tokens, -1) + + +def _naive_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +) -> torch.Tensor: + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d349c3ef19ea7..05b68bba5e6eb 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,5 +1,4 @@ """Attention layer with xFormers and PagedAttention.""" -import importlib from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -14,7 +13,6 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -166,11 +164,6 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - # AMD Radeon 7900 series (gfx1100) currently does not support xFormers - # nor FlashAttention. As a temporary workaround, we use naive PyTorch - # implementation of attention. - self.use_naive_attention = _check_use_naive_attention() - def forward( self, query: torch.Tensor, @@ -233,30 +226,6 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - if self.use_naive_attention: - output = torch.empty_like(query) - start = 0 - for _, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len - out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out) - start += prompt_len - - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformers_forward( query, key, value, attn_metadata) else: @@ -329,8 +298,6 @@ def _run_memory_efficient_xformers_forward( self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.prompt_lens) - op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( - is_hip()) else None # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -344,8 +311,7 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) return out.view_as(query) @@ -363,8 +329,7 @@ def _run_memory_efficient_xformers_forward( value[None, start:end], attn_bias=attn_metadata.attn_bias[i], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.squeeze(0)) start += prompt_len @@ -405,42 +370,3 @@ def _make_alibi_bias( attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) return attn_biases - - -def _check_use_naive_attention() -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - use_naive_attention = importlib.util.find_spec("flash_attn") is None - if use_naive_attention: - logger.warning("flash_attn is not installed. Using naive attention. " - "This will take significantly more GPU memory.") - return True - return False - - -def _naive_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - scale: float, -) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000000000..b86e845020b07 --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + hq, + hk, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * hq + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, + hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index b5cd39bbe6252..4c699aed48d49 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,3 +1,4 @@ +import enum from functools import lru_cache from typing import Type @@ -10,46 +11,68 @@ logger = init_logger(__name__) +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + + @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - if _can_use_flash_attn(dtype): + backend = _which_attn_to_use(dtype) + if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif is_cpu(): - logger.info("Using Torch SDPA backend.") - from vllm.attention.backends.torch_sdpa import TorchSDPABackend - return TorchSDPABackend - else: + elif backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) return XFormersBackend + elif backend == _Backend.ROCM_FLASH: + logger.info("Using ROCmFlashAttention backend.") + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend + else: + raise ValueError("Invalid attention backend.") -def _can_use_flash_attn(dtype: torch.dtype) -> bool: +def _which_attn_to_use(dtype: torch.dtype) -> _Backend: + """Returns which flash attention backend to use.""" + if is_cpu(): + return _Backend.TORCH_SDPA + if is_hip(): # AMD GPUs. - logger.info("Cannot use FlashAttention backend for AMD GPUs.") - return False - if is_cpu(): - return False + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs.") + return _Backend.ROCM_FLASH + + # NVIDIA GPUs. if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " "GPUs.") - return False + return _Backend.XFORMERS + if dtype not in (torch.float16, torch.bfloat16): logger.info("Cannot use FlashAttention backend for dtype other than " "torch.float16 or torch.bfloat16.") - return False + return _Backend.XFORMERS try: import flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention because the package is not found. " - "Please install it for better performance.") - return False - return True + "Cannot use FlashAttention backend because the flash_attn package " + "is not found. Please install it for better performance.") + return _Backend.XFORMERS + return _Backend.FLASH_ATTN