From a565083b19fab61ff0f8705f22254b2f47aedc1b Mon Sep 17 00:00:00 2001 From: Jeremy Francis Reizenstein Date: Mon, 19 Feb 2024 10:20:03 +0000 Subject: [PATCH] Export recent MHA changes from fbcode triton_splitk extensively expanded. Includes some paged attention, merge_attentions ghstack-source-id: 2e01b95df689395cb64844eb6db5e5638058e599 Pull Request resolved: https://github.com/fairinternal/xformers/pull/1031 __original_commit__ = fairinternal/xformers@19b9c2106e2d62828a7cac3a6f7a8ef16719cdef --- CHANGELOG.md | 2 + tests/test_mem_eff_attention.py | 444 +++++- tests/utils.py | 69 +- xformers/attn_bias_utils.py | 14 +- .../benchmarks/benchmark_attn_decoding.py | 5 +- xformers/ops/fmha/__init__.py | 62 + xformers/ops/fmha/attn_bias.py | 122 +- xformers/ops/fmha/common.py | 7 +- xformers/ops/fmha/flash.py | 4 +- xformers/ops/fmha/triton_splitk.py | 1326 +++++++++++------ 10 files changed, 1552 insertions(+), 503 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c26162a249..50a3e3b56e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.0.25] - TBD ### Added +- New merge_attentions function ### Improved - fMHA: Updated Flash-Attention to v2.5.2: this has a performance improvement for multiquery. +- fMHA: triton_splitk changed and expanded. Now amalgamates using LSE. Can autotune, supports causal with a small number of queries - not just 1. Experimental support for paged attention. ### Removed ## [0.0.24] - 2024-01-31 diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 441a36d9ae..ae717b32ab 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +import logging import math import random from functools import partial @@ -18,10 +19,10 @@ from xformers.attn_bias_utils import create_attn_bias from xformers.ops import fmha from xformers.ops.fmha import ALL_BW_OPS, ALL_FW_OPS -from xformers.ops.fmha.common import AttentionOpBase +from xformers.ops.fmha.common import AttentionFwOpBase, AttentionOpBase from xformers.ops.fmha.dispatch import _dispatch_fw_priority_list -from .utils import assert_allclose +from .utils import assert_allclose, pack_kv_cache torch.backends.cuda.matmul.allow_tf32 = False cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") @@ -43,6 +44,8 @@ "T", Type[fmha.common.AttentionFwOpBase], Type[fmha.common.AttentionBwOpBase] ) +logger = logging.getLogger("xformers") + def _filter_unsupported_ops(ops: Sequence[T]) -> Sequence[T]: return [ @@ -176,10 +179,10 @@ def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv( fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, }: Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2 - elif ( - bias_type - is fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask - ): + elif bias_type in { + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + }: Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) shape = (B, Mq, Mkv, H, K, Kv) combination.append((op, device, dtype, bias_type, *shape)) @@ -410,6 +413,7 @@ def create_tensors( ( fmha.attn_bias.BlockDiagonalMask, fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, ), ): query, key, value = [ @@ -1636,13 +1640,10 @@ def dequant_cache(x): @sm80_or_better_only -@pytest.mark.skipif( - fmha.triton_splitk.FwOp_S2.OPERATOR is None, reason="splitK disabled" -) @pytest.mark.parametrize( "op,dequant,dtype", [ - (fmha.triton_splitk.FwOp_S2, False, "bf16"), + (fmha.triton_splitk.FwOp_S1, False, "bf16"), (fmha.triton_splitk.FwOp_S2, False, "f16"), (fmha.triton_splitk.FwOp_S2, True, "bf16"), ( @@ -1666,9 +1667,6 @@ def test_triton_splitk_decoder( bsz: int, dtype: str, ) -> None: - if dequant: - pytest.skip("dequant is not supported") - # We omit dequant with f16: it needs a very high tol test_decoder( op, @@ -1681,6 +1679,42 @@ def test_triton_splitk_decoder( ) +@sm80_or_better_only +@pytest.mark.parametrize( + "op", + [ + fmha.triton_splitk.FwOp_S1, + fmha.triton_splitk.FwOp_S2, + ], + ids=lambda op: f"splitk{op.SPLIT_K}", +) +@pytest.mark.parametrize("multiquery", [True, False], ids=lambda x: "mq" if x else "") +# n_heads=1 => it's ambiguous whether can count as multiquery +@pytest.mark.parametrize("padding, bsz", [(32, 8), (44, 1)]) +@pytest.mark.parametrize("dtype", ["f16", "bf16"]) +@pytest.mark.parametrize("n_heads, num_queries", [(2, 4), (2, 5), (6, 7), (20, 3)]) +def test_triton_splitk_decoder_manyqueries( + op, + multiquery: bool, + n_heads: int, + padding: int, + bsz: int, + dtype: str, + num_queries: int, +) -> None: + kv_heads = 1 if multiquery else None + test_decoder( + op, + kv_heads=kv_heads, + n_heads=n_heads, + padding=padding, + bsz=bsz, + dtype=dtype, + num_queries=num_queries, + dequant=False, + ) + + def test_attn_bias_from_seqlens() -> None: bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens([3, 5, 1]) out = bias.split(torch.randn([1, 3 + 5 + 1, 16])) @@ -2093,7 +2127,7 @@ def test_forward_splitk( @cuda_only -@pytest.mark.parametrize("op", [fmha.triton_splitk.FwOp]) +@pytest.mark.parametrize("op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "B_Mkv_H_K", @@ -2323,4 +2357,386 @@ def test_cutlassB_iter_order( assert num_parallel_blocks == num_actual +@sm80_or_better_only +@pytest.mark.parametrize("B", [1, 5, 128]) +@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192]) +@pytest.mark.parametrize( + "op", + [ + fmha.triton_splitk.FwOp, + fmha.triton_splitk.FwOp_S8, + fmha.triton_splitk.FwOp_Map[48], + ], +) +@pytest.mark.parametrize("num_quant_groups", [0, 1, 8]) +@pytest.mark.parametrize("page_size", [64, 128, 256]) +def test_paged_attention( + B, MAX_T: int, num_quant_groups: bool, page_size: int, op: Type[AttentionFwOpBase] +): + paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False) + + +def paged_attention_run_inner( + B: int, + MAX_T: int, + num_quant_groups: bool, + page_size: int, + op: Type[AttentionFwOpBase], + bench: bool, +) -> None: + import triton + + torch.manual_seed(10) + TEST_WARMUP_MS = 500 + TEST_RUN_MS = 5000 + + N_H_L = 8 + N_KVH_L = 1 + D_H = 128 + D_H_KV = D_H // 8 + num_quant_groups if num_quant_groups else D_H + kv_seqlens = torch.randint(low=1, high=MAX_T + 1, size=(B,)).tolist() + + attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[1] * B, + kv_padding=MAX_T, + kv_seqlen=kv_seqlens, + ) + + q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda") + if num_quant_groups: + # Using high=64 below, because with 256 both paged and non-paged paths + # will produce NaNs - probably some quantization coeffitions are NaNs + # after the bitwise cast. + cache_k = torch.randint( + 0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda" + ) + cache_k = cache_k.view(dtype=torch.int32) + cache_v = torch.randint( + 0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda" + ) + cache_v = cache_v.view(dtype=torch.int32) + + op = type( + f"{op.__name__}_{num_quant_groups}", + (op,), + {"NUM_GROUPS": num_quant_groups}, + ) + else: + cache_k = torch.randn( + (B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v = torch.randn_like(cache_k) + + axq = q.view(1, B * 1, N_H_L, D_H) + axk = cache_k.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand( + 1, B * MAX_T, N_H_L, D_H_KV + ) + axv = cache_v.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand( + 1, B * MAX_T, N_H_L, D_H_KV + ) + + k_cache_size_usual = axk.numel() + + # First, create "wasteful" K/V cache, where every block in logical cache has a physical representation, + # even if there's nothing stored there + + # Paged attention requires k.shape[1] and v.shape[1] to be divisible by page_size, so pad + padded_per_row_len = ((MAX_T + page_size - 1) // page_size) * page_size + block_tables = torch.arange( + B * padded_per_row_len // page_size, device="cuda", dtype=torch.int32 + ).reshape(B, -1) + + shape_padded = (B, padded_per_row_len, N_KVH_L, D_H_KV) + axk_padded = torch.empty(shape_padded, device=axk.device, dtype=axk.dtype) + axv_padded = torch.empty(shape_padded, device=axv.device, dtype=axv.dtype) + axk_padded[:, :MAX_T] = axk.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :] + axv_padded[:, :MAX_T] = axv.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :] + + axk_padded = axk_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV) + axv_padded = axv_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV) + + axk_padded = axk_padded.expand(-1, -1, N_H_L, -1) + axv_padded = axv_padded.expand(-1, -1, N_H_L, -1) + + attn_bias_paged = attn_bias.make_paged( + block_tables=block_tables, page_size=page_size + ) + + y_usual = fmha.memory_efficient_attention_forward( + axq, + axk, + axv, + attn_bias, + op=op, + ) + if bench: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y_usual = fmha.memory_efficient_attention_forward( + axq, + axk, + axv, + attn_bias, + op=op, + ) + t_ms = triton.testing.do_bench( + lambda g=g: g.replay(), + warmup=TEST_WARMUP_MS, + rep=TEST_RUN_MS, + ) + logger.info(f"Non-paged attention took {t_ms * 1e3:.2f}us") + + y_wasteful = fmha.memory_efficient_attention_forward( + axq, + axk_padded, + axv_padded, + attn_bias_paged, + op=op, + ) + if bench: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y_wasteful = fmha.memory_efficient_attention_forward( + axq, + axk_padded, + axv_padded, + attn_bias_paged, + op=op, + ) + t_ms = triton.testing.do_bench( + lambda g=g: g.replay(), + warmup=TEST_WARMUP_MS, + rep=TEST_RUN_MS, + ) + logger.info(f"Paged attention with wasteful K/V-cache took {t_ms * 1e3:.2f}us") + + torch.testing.assert_close( + y_wasteful, + y_usual, + atol=1.0e-2, + rtol=1.0e-2, + ) + + # Now let's create a "packed" K/V cache, where only meaniningful logical blocks are mapped to physical blocks + (block_tables, packed_cache_k, packed_cache_v) = pack_kv_cache( + cache_k, + cache_v, + kv_seqlens, + page_size, + ) + attn_bias_paged = attn_bias.make_paged( + block_tables=block_tables, page_size=page_size + ) + axk = packed_cache_k.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV) + axv = packed_cache_v.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV) + + k_cache_size_packed = axk.numel() + + y_packed = fmha.memory_efficient_attention_forward( + axq, + axk, + axv, + attn_bias_paged, + op=op, + ) + + logger.info( + f"KV-cache size reduced by {(100 * (1 - k_cache_size_packed/k_cache_size_usual)):.2f}%" + ) + + torch.testing.assert_close(y_wasteful, y_packed) + + # Let's swap two blocks, and adjust two corresponding entries in the block table. The result shouldn't change + i, j = 0, axk.shape[1] // page_size - 1 + + axk = axk[:, :, :1, :] + axv = axv[:, :, :1, :] + + vals_i = axk[:, i * page_size : (i + 1) * page_size, :, :].clone() + vals_j = axk[:, j * page_size : (j + 1) * page_size, :, :].clone() + axk[:, i * page_size : (i + 1) * page_size, :, :] = vals_j + axk[:, j * page_size : (j + 1) * page_size, :, :] = vals_i + + vals_i = axv[:, i * page_size : (i + 1) * page_size, :, :].clone() + vals_j = axv[:, j * page_size : (j + 1) * page_size, :, :].clone() + axv[:, i * page_size : (i + 1) * page_size, :, :] = vals_j + axv[:, j * page_size : (j + 1) * page_size, :, :] = vals_i + + axk = axk.expand(-1, -1, N_H_L, -1) + axv = axv.expand(-1, -1, N_H_L, -1) + + where_i = block_tables == i + where_j = block_tables == j + block_tables.masked_fill_(where_i, j) + block_tables.masked_fill_(where_j, i) + + y_swapped = fmha.memory_efficient_attention_forward( + axq, + axk, + axv, + attn_bias_paged, + op=op, + ) + if bench: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y_swapped = fmha.memory_efficient_attention_forward( + axq, + axk, + axv, + attn_bias_paged, + op=op, + ) + t_ms = triton.testing.do_bench( + lambda g=g: g.replay(), + warmup=TEST_WARMUP_MS, + rep=TEST_RUN_MS, + ) + logger.info(f"Paged attention with packed K/V-cache took {t_ms * 1e3:.2f}us") + + torch.testing.assert_close(y_swapped, y_packed) + + +@sm80_or_better_only +def test_merging_attentions_decoding(): + """ + Compute decoding attention on chunks of K/V and merge them together. + Compare with computing attention on the whole K/V. + """ + + MAX_T = 8192 + B = 128 + N_KVH_L = 1 + N_H_L = 8 + D_H = 128 + dtype = torch.bfloat16 + + num_chunks = 10 + + chunk_starts = sorted( + torch.randint(low=1, high=MAX_T // 2, size=(num_chunks,)).tolist() + ) + chunk_starts[0] = 0 + chunk_starts.append(MAX_T) + + # We construct sequances so that even the last chunk has a non-empty part of every sequence. + # Otherwise the corresponding LSE will be -inf and that'll propagate to the whole sum. + # It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases + # of merging attention, e.g. a batch of sequences with a common prefix, this condition should be satisfied. + k_lens = torch.randint(low=chunk_starts[-2] + 1, high=MAX_T, size=(B,)).tolist() + q_lens = [1 for _ in k_lens] + B_T = sum(q_lens) + + q = torch.randn((1, B_T, N_H_L, D_H), dtype=dtype, device="cuda") + k = torch.randn((B, MAX_T, N_KVH_L, D_H), dtype=dtype, device="cuda") + v = torch.randn_like(k) + + # Compute per-chunk attention + chunks_output = [] + for i in range(num_chunks): + chunk_start, chunk_end = chunk_starts[i], chunk_starts[i + 1] + k_chunk = k[:, chunk_start:chunk_end, ...] + v_chunk = v[:, chunk_start:chunk_end, ...] + axk = k_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H) + axv = v_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H) + + attn_bias = ( + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=q_lens, + kv_padding=chunk_end - chunk_start, + kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens], + ) + ) + + attn_chunk, lse_chunk = fmha.memory_efficient_attention_forward_requires_grad( + q, + axk, + axv, + attn_bias, + ) + attn_chunk = attn_chunk.reshape(B, -1, N_H_L, D_H) + chunks_output.append((attn_chunk, lse_chunk)) + + # Merge attention from all chunks + attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output]) + lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output]) + attn_out, lse_out = fmha.merge_attentions( + attn_split.permute(0, 1, 3, 2, 4), lse_split + ) + + # Compute attention on the full K/V + attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=q_lens, + kv_padding=MAX_T, + kv_seqlen=k_lens, + ) + axk = k.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H) + axv = v.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H) + attn_full, lse_full = fmha.memory_efficient_attention_forward_requires_grad( + q, + axk, + axv, + attn_bias, + ) + + attn_out = attn_out.reshape(1, B_T, N_H_L, D_H) + torch.testing.assert_close(lse_out, lse_full, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(attn_out, attn_full, rtol=1e-3, atol=1e-3) + + +@sm80_or_better_only +@pytest.mark.parametrize("bmghk", (False, True)) +def test_merging_attentions_against_ref(bmghk: bool): + split_k = 16 + B = 12 + M = 137 + G = 2 if bmghk else 1 + N_H_L = 8 + D_H = 128 + dtype = torch.float32 + + attn_split = torch.randn([split_k, B, N_H_L, G, M, D_H], dtype=dtype, device="cuda") + lse_split = torch.randn([split_k, B, N_H_L, G, M], dtype=dtype, device="cuda") + + if not bmghk: + attn_split = attn_split[:, :, :, 0, :, :] + lse_split = lse_split[:, :, :, 0, :] + + attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split) + + attn_out_ref, lse_out_ref = _merge_attentions_ref(attn_split, lse_split) + + torch.testing.assert_close(lse_out, lse_out_ref, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(attn_out, attn_out_ref, rtol=1e-4, atol=1e-4) + + +def _merge_attentions_ref(attn_split, lse_split): + """ + attn_split: [split_k, B, H, G, M_ceil, Kq] + lse_split: [split_k, B, H, G, M] + """ + is_bmghk = len(attn_split.shape) == 6 + if not is_bmghk: + attn_split = attn_split.unsqueeze(3) + lse_split = lse_split.unsqueeze(3) + + lse_split = lse_split.unsqueeze(5) # [split_k, B, M, G, H, 1] + + lse_max, _ = torch.max(lse_split, dim=0, keepdim=True) # [1, B, M, G, H, 1] + sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1] + denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1] + numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, K] + + attn_out = numerator / denominator # [B, M_ceil, G, H, Kq] + lse_out = (lse_max.squeeze(0) + torch.log(denominator)).squeeze( + 4 + ) # [B, M_ceil, G, H] + + if not is_bmghk: + attn_out = attn_out.squeeze(2) + lse_out = lse_out.squeeze(2) + + return attn_out, lse_out + + # end of file diff --git a/tests/utils.py b/tests/utils.py index 77c606f897..5e0c8eb195 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import List, Optional, Tuple import numpy as np import torch @@ -36,3 +36,70 @@ def assert_allclose( f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}" f"/ total failing elements: {num_different}, percentage={percentage}" ) + + +def pack_kv_cache( + cache_k: torch.Tensor, + cache_v: torch.Tensor, + kv_seqlens: List[int], + BLOCK_N: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Create block tables and pages K/V cache for testing paged attention. + Args: + cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D]. + Note that these tensors are unexpanded, + i.e. for multiquery case cache_k.shape[2] = 1 + kv_seqlens: list of K/V sequence lengths + BLOCK_N: number of tokens per per paged attention block + B: batch size + Returns: + block_tables: [B, MAX_BLOCKS] + packed_cache_k: [1, total_len_rounded, H_kv, D] + packed_cache_v: [1, total_len_rounded, H_kv, D] + where total_len_rounded is a sum of K/V seqlens, each rounded up + to a multiple of BLOCK_N. + """ + + kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens] + + total_len_rounded = sum(kv_seqlens_rounded) + + B, MAX_T, H, D = cache_k.shape + + packed_cache_k = torch.empty( + total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype + ) + packed_cache_v = torch.empty( + total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype + ) + seqstart = 0 + for b in range(B): + packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[ + b, : kv_seqlens[b] + ].clone() + packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[ + b, : kv_seqlens[b] + ].clone() + seqstart += kv_seqlens_rounded[b] + + num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N + block_tables = ( + torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32) + .unsqueeze(0) + .expand(B, num_blocks_per_row) + ) + seqstarts = ( + ( + torch.tensor(kv_seqlens_rounded).cumsum(dim=0) + - torch.tensor(kv_seqlens_rounded) + ) + .to(device="cuda") + .unsqueeze(1) + ) // BLOCK_N + block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32) + return ( + block_tables, + packed_cache_k.unsqueeze(0), + packed_cache_v.unsqueeze(0), + ) diff --git a/xformers/attn_bias_utils.py b/xformers/attn_bias_utils.py index cc7f81ce64..b940e72166 100644 --- a/xformers/attn_bias_utils.py +++ b/xformers/attn_bias_utils.py @@ -150,7 +150,10 @@ def create_attn_bias( if bias_type is fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask: block_diag = block_diag.make_causal_from_bottomright() return block_diag - if bias_type == fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask: + if bias_type in [ + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ]: assert fmt in ["BMHK", "BMGHK"] q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len) g_block_diag = ( @@ -160,6 +163,15 @@ def create_attn_bias( kv_seqlen=k, ) ) + if bias_type == fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask: + page_size = r.choice([64, 128, 256]) + pages_per_row = (kv_len + page_size - 1) // page_size + block_tables = torch.randperm( + batch_size * pages_per_row, device=device + ).reshape(batch_size, pages_per_row) + return g_block_diag.make_paged( + block_tables=block_tables, page_size=page_size + ) return g_block_diag if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask: return bias_type( diff --git a/xformers/benchmarks/benchmark_attn_decoding.py b/xformers/benchmarks/benchmark_attn_decoding.py index 52feff7100..2300e233fe 100644 --- a/xformers/benchmarks/benchmark_attn_decoding.py +++ b/xformers/benchmarks/benchmark_attn_decoding.py @@ -67,7 +67,10 @@ def __init__( self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool ) -> None: dtype = torch.float16 - self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}" + self.sub_label = ( + f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes=" + f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}" + ) self.label = "attn_decoding" self.shapes = (B, Mq, Mkv, Hq, Hkv, K) diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index bf7368da3b..8b29eef9a8 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -414,6 +414,68 @@ def _memory_efficient_attention_backward( return grads +def merge_attentions( + attn_split: torch.Tensor, lse_split: torch.Tensor, write_lse: bool = True +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Combine attention output computed on different parts of K/V for the same + query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099 + The result is equal to + Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...) + LSE_full = log(exp(LSE1) + exp(LSE2) + ...) + Attention inputs are in BH(G)MK format, stacked along dim 0. Attention output also is in BH(G)MK. + Args: + attn_split: [split_k, B, H, G, M, Kq] or [split_k, B, H, M, Kq] + lse_split: [split_k, B, H, G, M] or [split_k, B, H, M] + Res: + attn_out: [B, H, G, M, K] or [B, H, M, K] + lse_out: [B, H, G, M] or [B, H, M] + """ + + assert ( + attn_split.ndim == lse_split.ndim + 1 + ), f"{attn_split.shape=} {lse_split.shape=}" + + is_bmhk = attn_split.ndim == 5 + if is_bmhk: + attn_split = attn_split.unsqueeze(3) + lse_split = lse_split.unsqueeze(3) + + split_k, B, H, G, M_ceil, Kq = attn_split.shape + split_k1, B1, H1, G1, M = lse_split.shape + assert ( + B == B1 and G == G1 and H == H1 and split_k == split_k1 and M_ceil >= M + ), f"{attn_split.shape=} {lse_split.shape=}" + + attn_split = attn_split.permute(1, 2, 3, 0, 4, 5).view( + B, H * G, split_k, M_ceil, Kq + ) + lse_split = lse_split.permute(1, 2, 3, 0, 4).view(B, H * G, split_k, M) + + attn_out = torch.empty( + B, H, G, M, Kq, device=attn_split.device, dtype=attn_split.dtype + ) + if write_lse: + lse_out = torch.empty( + B * H * G, M, device=attn_split.device, dtype=torch.float32 + ) + else: + lse_out = None + + triton_splitk.merge_attentions( + attn_out.permute(0, 3, 2, 1, 4), lse_out, attn_split, lse_split + ) + if lse_out is not None: + lse_out = lse_out.view(B, H, G, M) + + if is_bmhk: + attn_out = attn_out[:, :, 0] + if lse_out is not None: + lse_out = lse_out[:, :, 0] + + return attn_out, lse_out + + ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [ cutlass.FwOp, flash.FwOp, diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py index 78044f7db5..bf6b12842d 100644 --- a/xformers/ops/fmha/attn_bias.py +++ b/xformers/ops/fmha/attn_bias.py @@ -327,10 +327,14 @@ def intervals(self) -> Iterable[Tuple[int, int]]: yield from zip(self.seqstart_py, self.seqstart_py[1:]) @classmethod - def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + def _get_seqstart( + cls, seqlens: Iterable[int] + ) -> Tuple[int, int, List[int], torch.Tensor]: """ - Input tensors are assumed to be in shape [B, M, *] + Given sequence lengths, returns the min/max value and the sequence start + positions (offsets), with first element being 0 (returned in list and Tensor). """ + assert not isinstance(seqlens, torch.Tensor) seqstart_py = [0] max_seqlen = -1 @@ -340,6 +344,16 @@ def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": max_seqlen = max(max_seqlen, seqlen) seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + + return (min_seqlen, max_seqlen, seqstart_py, seqstart) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + min_seqlen, max_seqlen, seqstart_py, seqstart = cls._get_seqstart(seqlens) + return cls( max_seqlen=max_seqlen, min_seqlen=min_seqlen, @@ -440,7 +454,9 @@ def from_seqlens_padded( seqstart = padding * torch.arange(batch_size) """ assert not isinstance(seqlens, torch.Tensor) - assert all(seqlen <= padding for seqlen in seqlens) + assert all( + seqlen <= padding for seqlen in seqlens + ), f"Seqlens {seqlens} Padding {padding}" seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) return cls( seqlen=torch.tensor(seqlens, dtype=torch.int32), @@ -839,6 +855,106 @@ def from_seqlens( k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + def make_paged(self, block_tables: torch.Tensor, page_size: int): + paged_bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + paged_bias.k_seqinfo.padding = block_tables.shape[1] * page_size + return paged_bias + + +@dataclass +class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias): + """ + Same as BlockDiagonalCausalWithOffsetPaddedKeysMask, but for paged attention. + block_tables has shape [batch_size, max_num_pages] and K/V have shape + [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + block_tables: torch.Tensor + page_size: int + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + # First create a non-paged mask, then cut individual pages and + # copy them to their places in the physical mask, using block tables + + max_row_len = self.block_tables.shape[1] * self.page_size + logical_input_len = self.block_tables.shape[0] * max_row_len + bias_nonpaged = BlockDiagonalCausalWithOffsetPaddedKeysMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=_PaddedSeqLenInfo.from_seqlens_padded( + self.k_seqinfo.seqlen_py, max_row_len + ), + ) + mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device) + + mask_paged = torch.empty( + mask_nonpaged.shape[:-1] + (logical_input_len,), dtype=dtype, device=device + ) + mask_paged.fill_(-math.inf) + for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()): + for logical_page_idx in range(self.block_tables.shape[1]): + physical_page_idx = self.block_tables[b][logical_page_idx] + k_logical_start = ( + b * self.block_tables.shape[1] + logical_page_idx + ) * self.page_size + k_logical_end = k_logical_start + self.page_size + k_physical_start, k_physical_end = ( + physical_page_idx * self.page_size, + (physical_page_idx + 1) * self.page_size, + ) + mask_paged[ + ..., q_start:q_end, k_physical_start:k_physical_end + ] = mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end] + + return mask_paged + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Sequence[int], + block_tables: torch.Tensor, + page_size: int, + ) -> "PagedBlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, padding=block_tables.shape[1] * page_size + ) + return cls( + q_seqinfo=q_seqinfo, + k_seqinfo=k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + @dataclass class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index ff7dea3587..a7aaf7d887 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -288,11 +288,12 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: Returns a list of reasons why this is not supported. The kernel can run these inputs only if the returned list is empty """ + query_shape = d.query.shape reasons = cls.shape_not_supported_reasons( - Mq=d.query.shape[1], + Mq=query_shape[1], Mkv=d.key.shape[1], - K=d.query.shape[-1], - Kv=d.value.shape[-1], + K=query_shape[-1], + Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1], ) device_type = d.query.device.type dtype = d.query.dtype diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 20392ab890..acf09c3b68 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -135,7 +135,7 @@ def _flash_fwd( query, key, value, - None, + None, # out cu_seq_lens_q, cu_seq_lens_k, seqused_k, @@ -286,7 +286,7 @@ def _convert_input_format( max_seqlen_q = inp.query.shape[1] max_seqlen_k = inp.key.shape[1] - if query.ndim == 5: # QGA + if query.ndim == 5: # GQA assert supports_mqa # Fold the group/head_in_group dimensions together diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py index 1c4f6d9421..7225137629 100644 --- a/xformers/ops/fmha/triton_splitk.py +++ b/xformers/ops/fmha/triton_splitk.py @@ -4,128 +4,178 @@ # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import functools +import sys +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type import torch +import triton +import triton.language as tl -from ..common import _has_triton21, register_operator -from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs + +from ..common import register_operator +from .attn_bias import ( + BlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, +) from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1 -def _strides(x: torch.Tensor, *stride_names: str): +def _strides(x: Optional[torch.Tensor], *stride_names: str): + if x is None: + return {f"stride_{name}": None for name in stride_names} assert x.ndim == len(stride_names) - return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} - - -if TYPE_CHECKING or _has_triton21(): - import triton - import triton.language as tl - - from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs - - @triton.jit - def _fwd_kernel_splitK( - Q, - K, - V, - sm_scale, - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] - Seq_len, - stride_qz, - stride_qm, - stride_qg, - stride_qh, - stride_qk, - stride_kz, - stride_kn, - stride_kg, - stride_kh, - stride_kk, - stride_vz, - stride_vn, - stride_vg, - stride_vh, - stride_vk, - stride_osk_zhg, - stride_osk_s, - stride_osk_m, - stride_osk_k, - stride_mzhg, - stride_m2, - stride_ms, - stride_mm, - Z, - N_CTX_Q, - N_CTX_K, - BLOCK_N_PER_SPLIT, - H: tl.constexpr, - G: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - BOUNDS_CHECKS_N: tl.constexpr, - USE_SEQ_LEN: tl.constexpr, - PACKED_PER_VAL: tl.constexpr = 1, - N_GROUPS: tl.constexpr = 1, - ): - """This kernel can accept non-quantized or int4-quantized keys/values. - PACKED_PER_VAL determines the quantization type: - - PACKED_PER_VAL == 1 means no quantization - - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) - For the quantized case K/V should be int32 tensors. - Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. - Quantization coefficients are stored at the beginning of the row along the last dimension of K/V - So K[B, H, M, :] has a form - [ quant_coef0, quant_coef1, ...| - group0_quant_value0, group0_quant_value1,... | - group1_quant_value0, group1_quant_value1,...] - where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. - - Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs - before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists. - See how FwOp.apply does it below. - """ - tl.static_assert( - (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) - or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), - f"Only 4-bit quantization is supported, K/V should have dtype int32 in " - f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", - ) - tl.static_assert( - (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), - "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", - ) - - QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 - PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS - D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS - - start_m = tl.program_id(0) - off_zhg = tl.program_id(1) - off_z = off_zhg // (H * G) - off_h = (off_zhg // G) % H - off_g = off_zhg % G - splitk_idx = tl.program_id(2) - - lo = splitk_idx * BLOCK_N_PER_SPLIT - if USE_SEQ_LEN: - kv_len = tl.load(Seq_len + off_z) - else: - kv_len = N_CTX_K - hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) - - Q_block_ptr = tl.make_block_ptr( - base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, - shape=(N_CTX_Q, D_PER_GROUP), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, D_PER_GROUP), - order=(1, 0), - ) - - k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + return {f"stride_{name}": s for name, s in zip(stride_names, x.stride())} + + +AUTOTUNER_KEY = [ + "Z", + "H", + "G", + "N_CTX_Q", + "N_CTX_K", + "BLOCK_DMODEL", + "PACKED_PER_VAL", + "N_GROUPS", + "BLOCK_N_PER_SPLIT", +] + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + LSE_splitk, # [B, H, split_k, Mq] + block_tables, + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_z, + stride_osk_hg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_lsek_z, + stride_lsek_hg, + stride_lsek_s, + stride_lsek_m, + stride_blocktablesz, + stride_blocktablesl, + kv_cache_blocks_per_row: tl.constexpr, + Z: tl.constexpr, + N_CTX_Q: tl.constexpr, # The number of queries + N_CTX_K: tl.constexpr, + BLOCK_N_PER_SPLIT: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + N_GROUPS: tl.constexpr, + # It's important that BOUNDS_CHECKS_N, BLOCK_M, BLOCK_N come at the end of + # the argument list, since they are provided by the heuristics/autotune decorator. + # Otherwise Triton throws IndexError + BOUNDS_CHECKS_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_SPLITK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + NUM_QUERIES_CAUSAL: tl.constexpr, # The N_CTX_Q queries are from this many sequence positions + USE_PAGED_ATTENTION: tl.constexpr, + PAGE_SIZE: tl.constexpr, + WRITE_LSE: tl.constexpr, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs + before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists. + See how FwOp.apply does it below. + + Set IS_SPLITK=False to indicate the MHA result should be written directly. + No metadata will be written. + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_hg = off_zhg % (H * G) + off_h = off_hg // G + off_g = off_hg % G + splitk_idx = tl.program_id(2) + + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + + k_base = K + off_h * stride_kh + off_g * stride_kg + v_base = V + off_h * stride_vh + off_g * stride_vg + + # Boundaries of split-k chunk + chunk_hi = (splitk_idx + 1) * BLOCK_N_PER_SPLIT + chunk_lo = splitk_idx * BLOCK_N_PER_SPLIT + # For paged attention case K/V_block_ptr are defined inside the loop + # whereas for non-paged case they are defined before the loop. + if PAGE_SIZE > 0: + # Page contains several blocks + BLOCKS_IN_PAGE: tl.constexpr = PAGE_SIZE // BLOCK_N + # Align boundaries of split-k chunk to page boundaries + # In the last chunk, shift hi to the right, in the other chunks, shift it to the left + is_last_chunk = splitk_idx == tl.num_programs(2) - 1 + shift = PAGE_SIZE - 1 if is_last_chunk else 0 + lo = (chunk_lo // PAGE_SIZE) * PAGE_SIZE + hi = ((chunk_hi + shift) // PAGE_SIZE) * PAGE_SIZE + hi = tl.minimum(hi, kv_len) + block_table = block_tables + stride_blocktablesz * off_z + # Offset in integer blocks + logical_block_idx = lo // BLOCK_N + else: + lo = chunk_lo + hi = tl.minimum(chunk_hi, kv_len) + k_base += off_z * stride_kz + v_base += off_z * stride_vz # Additional shift by 1 along the last dimension in the quantized case, since # the first element along that dim contains packed quantization coefficients. K_block_ptr = tl.make_block_ptr( @@ -136,7 +186,6 @@ def _fwd_kernel_splitK( block_shape=(PACKED_D_PER_GROUP, BLOCK_N), order=(0, 1), ) - v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg V_block_ptr = tl.make_block_ptr( base=v_base + stride_vk * QUANTIZED * N_GROUPS, shape=(hi, PACKED_D_PER_GROUP), @@ -170,68 +219,151 @@ def _fwd_kernel_splitK( K_scale_shift_block_ptr = None V_scale_shift_block_ptr = None - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs. + # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS. + # This is a solution for Triton native lack of support for lists of tensors. + acc: "VAR_ARGS_ARRAY" # noqa: F821 + + for i in range(len(acc)): # noqa: F821 + acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(acc)): # noqa: F821 + q[i] = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,) + ) - # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs. - # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS. - # This is a solution for Triton native lack of support for lists of tensors. - acc: "VAR_ARGS_ARRAY" # noqa: F821 + if IS_CAUSAL: + # Why does the masking conditon below work as a causal mask? + # Assuming num_queries <= BLOCK_M: + # kv_pos = kv_start + range(0, BLOCK_N) + # q_offset = start_m * BLOCK_M + range(0, BLOCK_M) + # q_pos = kv_start + kv_len - num_queries + q_offset % num_queries + # mask = q_pos - kv_pos >= 0 + # So the final masking condition is: + # range(0, BLOCK_M) % num_queries - range(0, BLOCK_N) >= num_queries - kv_len + + q_offset = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + diag_idx = (q_offset[:, None] % NUM_QUERIES_CAUSAL) - tl.arange(0, BLOCK_N)[ + None, : + ] + diag_idx_shifted = tl.constexpr(diag_idx - NUM_QUERIES_CAUSAL + kv_len) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + if PAGE_SIZE > 0: + # Offset in integer blocks from the beginning of the page + block_offset_in_page = logical_block_idx % BLOCKS_IN_PAGE + # Offset in integer pages + logical_page_idx = logical_block_idx // BLOCKS_IN_PAGE + physical_page_idx = tl.load( + block_table + stride_blocktablesl * logical_page_idx + ).to(tl.int32) + offset = physical_page_idx * PAGE_SIZE + block_offset_in_page * BLOCK_N + + current_block_size = min(hi - start_n, BLOCK_N) + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, offset + current_block_size), + strides=(stride_kk, stride_kn), + offsets=(0, offset), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(offset + current_block_size, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(offset, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + if QUANTIZED: + # Pointers to quantization coefficients. Even those they are 1D, + # we have to use block pointers, since usual pointers + # don't support boundary checks + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, offset + current_block_size), + strides=(stride_kk, stride_kn), + offsets=(0, offset), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(offset + current_block_size, 1), + strides=(stride_vn, stride_vk), + offsets=(offset, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + logical_block_idx += 1 + k: "VAR_ARGS_ARRAY" # noqa: F821 + v: "VAR_ARGS_ARRAY" # noqa: F821 for i in range(len(acc)): # noqa: F821 - acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q: "VAR_ARGS_ARRAY" # noqa: F821 - for i in range(len(acc)): # noqa: F821 - q[i] = tl.load( # noqa: F821 - tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,) + k[i], v[i] = load_dequantize_k_v_group( # noqa: F821 + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + i, ) - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - k: "VAR_ARGS_ARRAY" # noqa: F821 - v: "VAR_ARGS_ARRAY" # noqa: F821 - for i in range(len(acc)): # noqa: F821 - k[i], v[i] = load_dequantize_k_v_group( # noqa: F821 - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N, - PACKED_PER_VAL, - PACKED_D_PER_GROUP, - Q.dtype.element_ty, - i, - ) - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - for i in range(len(acc)): # noqa: F821 - qk += tl.dot(q[i], k[i]) # noqa: F821 - qk *= qk_scale - - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? - if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(Q.dtype.element_ty) - - # -- scale and update acc -- - for i in range(len(acc)): # noqa: F821 - acc[i] *= alpha[:, None] # noqa: F821 - acc[i] += tl.dot(p, v[i]) # noqa: F821 + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for i in range(len(acc)): # noqa: F821 + qk += tl.dot(q[i], k[i]) # noqa: F821 + qk *= qk_scale + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + if IS_CAUSAL: + # -- apply the causal mask -- + p = tl.where(diag_idx_shifted >= start_n, p, 0) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + for i in range(len(acc)): # noqa: F821 + acc[i] *= alpha[:, None] # noqa: F821 + acc[i] += tl.dot(p, v[i]) # noqa: F821 + + if not PAGE_SIZE: # update pointers K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -243,215 +375,303 @@ def _fwd_kernel_splitK( V_scale_shift_block_ptr, (BLOCK_N, 0) ) - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, - shape=(N_CTX_Q, D_PER_GROUP), - strides=(stride_osk_m, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, D_PER_GROUP), - order=(1, 0), + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + + off_z * stride_osk_z + + off_hg * stride_osk_hg + + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + for i in range(len(acc)): # noqa: F821 + # If for the current batch element there are no tokens in the current split-k chunk (because + # seqlen is too short), l_i will be 0, so we need to make sure attention is filled with zeros and not NaNs. + attn_out = tl.where(l_i[:, None] == 0, 0.0, acc[i] / l_i[:, None]) # noqa: F821 + tl.store( + tl.advance(O_block_ptr, (0, i * D_PER_GROUP)), + attn_out.to(Out_splitK.dtype.element_ty), # noqa: F821 + boundary_check=(0,), ) - for i in range(len(acc)): # noqa: F821 - tl.store( - tl.advance(O_block_ptr, (0, i * D_PER_GROUP)), - acc[i], # noqa: F821 - boundary_check=(0,), - ) - # Write metadata for split-K reduction - Metadata_ptr = ( - Metadata - + off_zhg * stride_mzhg - + splitk_idx * stride_ms - + start_m * BLOCK_M - + tl.arange(0, BLOCK_M) + if WRITE_LSE: + LSE_splitk_ptr = ( + LSE_splitk + + off_z * stride_lsek_z + + off_hg * stride_lsek_hg + + splitk_idx * stride_lsek_s + + (start_m * BLOCK_M + tl.arange(0, BLOCK_M)) * stride_lsek_m + ) + mask = start_m * BLOCK_M + tl.arange(0, BLOCK_M) < N_CTX_Q + lse_dtype = LSE_splitk.dtype.element_ty # Can be float64 to improve numerics + tl.store( + LSE_splitk_ptr, + (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / 1.44269504, + mask=mask, ) - tl.store(Metadata_ptr, m_i) - tl.store(Metadata_ptr + stride_m2, l_i) - - @triton.jit - def load_dequantize_k_v_group( - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N: tl.constexpr, - PACKED_PER_VAL: tl.constexpr, - PACKED_D_PER_GROUP: tl.constexpr, - dtype: tl.constexpr, - group_id: tl.constexpr, - ): - """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading. - If quantization is group-wise, use group_id to advance the pointers to the current group. - """ - # Advance to the current quantization group - K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) - V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) - # -- load k, v -- - k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) - v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) - if PACKED_PER_VAL > 1: - # K/V are quantized, load quantization coefficients and dequantize +def gen_config( + block_m: int, + block_n: int, + stages: int, + warps: int, +) -> triton.Config: + """A more compact way to define a triton.Config, so it fits on one line""" - K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) - V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + return triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + }, + num_stages=stages, + num_warps=warps, + ) - k_scale_shift = tl.load( - K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () - ) - v_scale_shift = tl.load( - V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + +def _get_splitk_kernel(num_groups): + """ + Kernel _fwd_kernel_splitK needs to be post-processed by unroll_varargs + to specialize it for a given number of quantization groups N_GROUPS + before we can apply triton.heuristics and triton.autotune, so we + don't do them as decorators. + """ + + _fwd_kernel_splitK_unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups) + kernel = triton.heuristics( + { + "BOUNDS_CHECKS_N": lambda args: ( + args["BLOCK_N_PER_SPLIT"] % args["BLOCK_N"] ) + > 0 + or args["USE_SEQ_LEN"] + } + )(_fwd_kernel_splitK_unrolled) + return kernel + + +@functools.lru_cache(None) +def autotune_kernel(kernel: Callable): + BLOCK_M_VALUES = [16, 32] + BLOCK_N_VALUES = [32, 64, 128] + STAGES_VALUES = [1, 2, 3] + WARPS_VALUES = [1, 2, 4] + + TRITON_CONFIGS = [ + gen_config(block_m, block_n, stages, warps) + for block_m in BLOCK_M_VALUES + for block_n in BLOCK_N_VALUES + for stages in STAGES_VALUES + for warps in WARPS_VALUES + ] + + kernel = triton.autotune( + configs=TRITON_CONFIGS, + key=AUTOTUNER_KEY, + )(kernel) + return kernel + + +# This object contains forward kernels wrapped into autotuner for different number +# of quantization groups. +_fwd_kernel_splitK_autotune: Dict[int, triton.runtime.Autotuner] = {} +# The loop below: +# - transforms the jitted kernel with unroll_varargs producing a new kernel of each value of num_groups +# - wraps the kernel into triton.heuristics +# - wraps kernel into Triton autotuner. Autotuning itself happens the first time the kernel is called +if sys.version_info >= (3, 9): + # unroll_varargs requires Python 3.9+ + for num_groups in [1, 2, 4, 8]: + _fwd_kernel_splitK_autotune[num_groups] = autotune_kernel( + _get_splitk_kernel(num_groups) + ) - k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) - v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) - v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) - k_t = dequantize( - tl.trans(k), - tl.trans(k_scale), - tl.trans(k_shift), - PACKED_PER_VAL, - ).to(dtype) - k = tl.trans(k_t) - return k, v - - @triton.jit - def cast_uint32_to_half2(scale_shift): - """Extract two float16 packed into one int32""" - scale = scale_shift & 0xFFFF - shift = scale_shift >> 16 - scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) - shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) - return scale, shift - - @triton.jit - def dequantize( - x_, - scale, - shift, - PACKED_PER_VAL: tl.constexpr = 8, - ): - """PACKED_PER_VAL is the number of values packed into each element x_. - For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8. + def get_autotuner_cache(num_groups: int) -> Dict[Tuple[int], triton.Config]: + """Returns a triton.runtime.autotuner.AutoTuner.cache object, which + represents mappings from kernel autotune keys (tuples describing kernel inputs) + to triton.Config """ + return _fwd_kernel_splitK_autotune[num_groups].cache + + def set_autotuner_cache( + cache: Dict[Tuple[int], triton.Config], num_groups: int + ) -> None: + _fwd_kernel_splitK_autotune[num_groups].cache = cache + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading. + If quantization is group-wise, use group_id to advance the pointers to the current group. + """ + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) - # Axis along which offsets are applied matters here - # It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL) - # and expand K/V to that shape before applying offsets - # However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2 - # Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends - # on the implementation details which might change in the future. - # Ideally we would like to use tl.reshape, but it's not implemented yet. - # See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501 - - # x_ : (BLOCK_N, D // PACKED_PER_VAL) - # scale: (BLOCK_N, 1) - # offsets: (PACKED_PER_VAL,) - BLOCK_N: tl.constexpr = x_.shape[0] - BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] - offsets = tl.arange(0, PACKED_PER_VAL) * 4 - quant_offset = ( - x_[:, None, :] >> offsets[None, :, None] - ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) - - quant_offset = tl.view( - quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () ) - # Trick - instead of converting int4 to float16 we view it as float16 - # and then multiply by 32768 * 512 == 2**24 - quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) - quant_offset = (quant_offset * 32768.0).to(tl.float16) - scale_512 = scale * 512 - - dequant = quant_offset * scale_512 + shift - return dequant - - @triton.jit - def _splitK_reduce( - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] - Out, # [B, H, M, K] - LSE, # [B, H, M] - split_k, - stride_osk_zhg, - stride_osk_s, - stride_osk_m, - stride_osk_k, - stride_mzhg, - stride_m2, - stride_ms, - stride_mm, - stride_oz, - stride_oh, - stride_og, - stride_om, - stride_ok, - stride_lse_zhg, - stride_lse_m, - BLOCK_SIZE: tl.constexpr, - H: tl.constexpr, - G: tl.constexpr, - ): - off_zhg = tl.program_id(0) - off_z = off_zhg // (H * G) - off_h = (off_zhg // G) % H - off_g = off_zhg % G - off_m = tl.program_id(1) - - Out_splitK_ptr = ( - Out_splitK - + stride_osk_zhg * off_zhg - + stride_osk_m * off_m - + tl.arange(0, BLOCK_SIZE) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () ) - Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m - m = tl.load(Metadata_ptr) - l_sum = tl.load(Metadata_ptr + stride_m2) - acc = tl.load(Out_splitK_ptr) - - for split_k_idx in range(1, split_k): - Metadata_ptr = Metadata_ptr + stride_ms - Out_splitK_ptr = Out_splitK_ptr + stride_osk_s - - m_k = tl.load(Metadata_ptr) - l_k = tl.load(Metadata_ptr + stride_m2) - acc_k = tl.load(Out_splitK_ptr) - - m_new = tl.maximum(m, m_k) - if m_k < m: - # Scale incoming values - alpha = tl.math.exp2(m_k - m_new) - acc_k = acc_k * alpha - l_k = l_k * alpha - else: - # Scale our values - alpha = tl.math.exp2(m - m_new) - acc = acc * alpha - l_sum = l_sum * alpha - - m = m_new - l_sum = l_sum + l_k - acc = acc + acc_k - - acc = acc / l_sum - Out_ptr = ( - Out - + stride_oz * off_z - + stride_oh * off_h - + stride_og * off_g - + stride_om * off_m - + tl.arange(0, BLOCK_SIZE) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + """Extract two float16 packed into one int32""" + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + """PACKED_PER_VAL is the number of values packed into each element x_. + For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8. + """ + # x_ : (BLOCK_N, D // PACKED_PER_VAL) + # scale: (BLOCK_N, 1) + # offsets: (PACKED_PER_VAL,) + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, :, None, :] >> offsets + ) # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL) + + quant_offset = tl.reshape( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + LSE_splitK, # [B, H, split_k, Mq] + Out, # [B, H, M, K] + LSE, # [B, H, M] + split_k: tl.constexpr, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_lsek_zhg, + stride_lsek_s, + stride_lsek_m, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + WRITE_LSE: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + + Out_splitK_ptr = ( + Out_splitK + + stride_osk_zhg * off_zhg + + stride_osk_m * off_m + + tl.arange(0, BLOCK_SIZE) + ) + + LSE_splitK_ptr0 = LSE_splitK + stride_lsek_zhg * off_zhg + stride_lsek_m * off_m + LSE_splitK_ptr = LSE_splitK_ptr0 + lse_max = tl.load(LSE_splitK_ptr) + for split_k_idx in tl.static_range(1, split_k): + LSE_splitK_ptr = LSE_splitK_ptr + stride_lsek_s + lse_splitk = tl.load(LSE_splitK_ptr) + lse_max = tl.maximum(lse_max, lse_splitk) + + sumexp_normalized = 0.0 + numerator_normalized = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + LSE_splitK_ptr = LSE_splitK_ptr0 + for split_k_idx in tl.static_range(0, split_k): + out_splitk = tl.load(Out_splitK_ptr) + lse_splitk = tl.load(LSE_splitK_ptr) + # Compute denominator + sumexp_normalized_splitk = tl.math.exp2( + (lse_splitk - lse_max).to(tl.float32) * 1.44269504 ) - tl.store(Out_ptr, acc) + sumexp_normalized += sumexp_normalized_splitk + + # Compute numerator + numerator_normalized += out_splitk * sumexp_normalized_splitk + LSE_splitK_ptr = LSE_splitK_ptr + stride_lsek_s + Out_splitK_ptr = Out_splitK_ptr + stride_osk_s - l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m - tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504) + acc = numerator_normalized / sumexp_normalized -else: - _fwd_kernel_splitK = None - _splitK_reduce = None + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc) + + if WRITE_LSE: + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m * stride_lse_m + tl.store(l_ptrs, (lse_max + tl.math.log2(sumexp_normalized) / 1.44269504)) @register_operator @@ -475,6 +695,11 @@ class FwOp(AttentionFwOpBase): group_dequant = group_quant[..., 1:] * scale + shift ... + This op uses Paged Attention when bias is PagedBlockDiagonalCausalWithOffsetPaddedKeysMask. + In this case bias has additional fields: + - block_tables of shape [batch_size, max_num_pages] + - K/V of shape [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] """ OPERATOR = _fwd_kernel_splitK @@ -488,6 +713,7 @@ class FwOp(AttentionFwOpBase): SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { type(None), BlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, } SUPPORTS_DROPOUT = False SUPPORTS_CUSTOM_SCALE = True @@ -495,10 +721,19 @@ class FwOp(AttentionFwOpBase): NAME = "triton_splitKF" SPLIT_K: Optional[int] = None - BLOCK_M = 16 - BLOCK_N = 64 + MAX_BLOCK_M = 32 + + # Perform kernel-level Triton autotune + AUTOTUNE = False NUM_GROUPS = 1 # Default quantization is row-wise + NUM_GROUPS_VALUES = [1, 2, 4, 8] + + # values used when autotune=False + BLOCK_M: int = 16 + BLOCK_N: int = 64 + NUM_STAGES: int = 1 + NUM_WARPS: int = 2 @classmethod def shape_not_supported_reasons( @@ -526,8 +761,14 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: ) q_len = d.query.shape[1] - if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): - seqinfo = d.attn_bias.q_seqinfo + is_block_diagonal = isinstance( + d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) + is_paged = isinstance( + d.attn_bias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask + ) + if is_block_diagonal or is_paged: + seqinfo = d.attn_bias.q_seqinfo # type: ignore if q_len != seqinfo.seqstart_py[-1]: reasons.append( f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}" @@ -537,15 +778,28 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons.append( "Variable query len is not supported in the presence of causal mask." ) + if q_len > 16: + # 16 is the minimum BLOCK_M which gets used + # XXX I don't really understand why this is needed. + reasons.append("Query length should not be larger than 16") + + if is_paged: + page_size = d.attn_bias.page_size # type: ignore + if d.key.shape[1] % page_size: + reasons.append( + "For paged attention, key.shape[1] should be divisible " + "by the page size, " + f"but got {d.key.shape[1]=}, {page_size=}." + ) + if cls.AUTOTUNE: + reasons.append("Paged attention doesn't support autotuning yet.") + if page_size % cls.BLOCK_N: + reasons.append( + "For paged attention, page size should be divisible " + "by the block size, " + f"but got {page_size=}, {cls.BLOCK_N=}." + ) - if d.key.ndim in [4, 5] and d.key.shape[-2] != 1: - if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1: - reasons.append("multiquery is only supported with query seqlen=1") - - if d.attn_bias is not None and q_len > 1: - reasons.append( - "query with seqlen > 1 is not supported in the presence of causal mask" - ) return reasons @classmethod @@ -567,29 +821,51 @@ def apply( attn_bias = inp.attn_bias seq_len = None q, k, v = inp.get_qkv_in_bmghk() + IS_CAUSAL = False + NUM_QUERIES_CAUSAL = 1 + is_block_diagonal = isinstance( + attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) + is_paged = isinstance( + attn_bias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask + ) if attn_bias is not None: - assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + assert is_paged or is_block_diagonal # TODO: do we really need to do this cast? seems fishy but # I just copied it from the decoder.py - attn_bias.k_seqinfo.to(inp.query.device) - attn_bias.q_seqinfo.to(inp.query.device) - seq_len = attn_bias.k_seqinfo.seqlen + attn_bias.k_seqinfo.to(inp.query.device) # type: ignore + attn_bias.q_seqinfo.to(inp.query.device) # type: ignore + seq_len = attn_bias.k_seqinfo.seqlen # type: ignore B = len(seq_len) - G, H, Kq = q.shape[-3:] + G, Hq, Kq = q.shape[-3:] Kkv = v.shape[-1] # assume kv has been padded - q = q.reshape(B, -1, G, H, Kq) - k = k.reshape(B, -1, G, H, Kkv) - v = v.reshape(B, -1, G, H, Kkv) + q = q.reshape(B, -1, G, Hq, Kq) + if is_paged: + k = k.view(1, -1, G, Hq, Kkv) + v = v.view(1, -1, G, Hq, Kkv) + else: + k = k.reshape(B, -1, G, Hq, Kkv) + v = v.reshape(B, -1, G, Hq, Kkv) + Mq = q.shape[1] + IS_CAUSAL = Mq > 1 + NUM_QUERIES_CAUSAL = Mq + else: + B, Mq, G, Hq, Kq = q.shape - # Transpose in the case of MQA/GQA + # In the case of MQA/GQA, we make q have sequence length (H * Mq) and only one "head". mqa_swap_seqlen_head = False - if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + if ( + not needs_gradient + and k.shape[3] > 1 + and k.stride(3) == 0 + and v.stride(3) == 0 + ): mqa_swap_seqlen_head = True - assert q.shape[1] == 1 - q = q.transpose(1, 3) + # This is a copy iff Mq, G and H are all > 1. + q = q.permute(0, 3, 1, 2, 4).reshape(B, -1, G, 1, Kq) k = k[:, :, :, :1] v = v[:, :, :, :1] @@ -605,136 +881,230 @@ def apply( B, M, G, H, Kq = q.shape assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" - BLOCK_M = cls.BLOCK_M - BLOCK_N = cls.BLOCK_N + page_size = inp.attn_bias.page_size if is_paged else 0 # type: ignore + block_tables = None + kv_cache_blocks_per_row = 0 + if is_paged: + block_tables = inp.attn_bias.block_tables # type: ignore + kv_cache_blocks_per_row = block_tables.shape[1] + Mk = block_tables.shape[1] * page_size + elif attn_bias is not None: + Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen) # type: ignore + if cls.SPLIT_K is not None: split_k = cls.SPLIT_K else: # Use heuristics split_k = cls.get_split_k(B, H, Mk) - M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M - o_splitk = torch.empty( - [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device - ) - metadata = torch.empty( - [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device - ) - lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) - grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + if is_paged: + # Avoid having more than one split per page + split_k = min(split_k, block_tables.shape[1]) # type: ignore + # M_ceil = M rounded up to a multiple of MAX_BLOCK_M + M_ceil = (M + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M + IS_SPLITK = split_k > 1 # or cls.autotune? + if IS_SPLITK: + o_splitk = torch.empty( + [B, G * H, split_k, M_ceil, Kq], + dtype=torch.float32, + device=q.device, + ) + else: + o_splitk = torch.empty( + [B, split_k, M, G * H, Kq], + dtype=q.dtype, + device=q.device, + ).permute(0, 3, 1, 2, 4) + lse, lse_splitk = None, None + if IS_SPLITK and needs_gradient: + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + if IS_SPLITK or needs_gradient: + lse_splitk = torch.empty( + [B, G * H, split_k, M], + dtype=torch.float64 if IS_SPLITK else torch.float32, + device=q.device, + ) + + def grid(META): + return triton.cdiv(M, META["BLOCK_M"]), B * G * H, split_k - num_warps = 2 split_size = (Mk + split_k - 1) // split_k use_seq_len = seq_len is not None - _fwd_kernel_splitK_unrolled = unroll_varargs( - _fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1 - ) - _fwd_kernel_splitK_unrolled[grid]( + num_groups = cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1 + if cls.AUTOTUNE: + kernel = _fwd_kernel_splitK_autotune[num_groups] + extra_args = {} + else: + kernel = _get_splitk_kernel(num_groups) + extra_args = { + "BLOCK_M": cls.BLOCK_M, + "BLOCK_N": cls.BLOCK_N, + "num_warps": cls.NUM_WARPS, + "num_stages": cls.NUM_STAGES, + } + kernel[grid]( Q=q, K=k, V=v, sm_scale=inp.scale_float, Out_splitK=o_splitk, - Metadata=metadata, + LSE_splitk=lse_splitk, + block_tables=block_tables, Seq_len=seq_len, **_strides(q, "qz", "qm", "qg", "qh", "qk"), **_strides(k, "kz", "kn", "kg", "kh", "kk"), **_strides(v, "vz", "vn", "vg", "vh", "vk"), - **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(o_splitk, "osk_z", "osk_hg", "osk_s", "osk_m", "osk_k"), + **_strides(lse_splitk, "lsek_z", "lsek_hg", "lsek_s", "lsek_m"), + **_strides(block_tables, "blocktablesz", "blocktablesl"), + kv_cache_blocks_per_row=kv_cache_blocks_per_row, Z=B, H=H, G=G, N_CTX_Q=M, N_CTX_K=Mk, BLOCK_N_PER_SPLIT=split_size, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, USE_SEQ_LEN=use_seq_len, - num_warps=num_warps, - num_stages=1, PACKED_PER_VAL=PACKED_PER_VAL, - N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + N_GROUPS=num_groups, + IS_CAUSAL=IS_CAUSAL, + NUM_QUERIES_CAUSAL=NUM_QUERIES_CAUSAL, + IS_SPLITK=IS_SPLITK, + USE_PAGED_ATTENTION=is_paged, + PAGE_SIZE=page_size, + WRITE_LSE=IS_SPLITK or needs_gradient, + **extra_args, ) + if not IS_SPLITK: + out = o_splitk[:, :, 0].view(B, G, -1, Mq, Kq) + # This is a copy iff mqa_swap_seqlen_head and Mq, G and Hq are all > 1. + out = out.permute(0, 3, 1, 2, 4).contiguous() + if needs_gradient: + assert lse_splitk is not None + lse = lse_splitk[:, :, 0].view(B, G, -1, Mq) + else: + lse = None + + if inp.query.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + if lse is not None: + lse = lse[:, 0] + out = out[:, :, 0] + + if lse is None: + return out, None + return out, Context(out=out, lse=lse) if mqa_swap_seqlen_head: - out = torch.empty( - (B, H, G, M, Kq), device=q.device, dtype=q.dtype - ).transpose(1, 3) + out = torch.empty((B, G, M, 1, Kq), device=q.device, dtype=q.dtype).permute( + 0, 2, 1, 3, 4 + ) else: out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) - # Merge together - grid = (B * G * H, M, 1) - _splitK_reduce[grid]( - o_splitk, - metadata, - out, - lse, - split_k=split_k, - **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), - **_strides(out, "oz", "om", "og", "oh", "ok"), - **_strides(lse, "lse_zhg", "lse_m"), - BLOCK_SIZE=out.shape[-1], - G=G, - H=H, - # TODO: Tune num_warps - ) - lse = lse.reshape([B, G, H, M]) + # Merge attention and LSE outputs from different split-k chunks + assert lse_splitk is not None + merge_attentions(out, lse, o_splitk, lse_splitk) + if lse is not None: + lse = lse.reshape([B, G, H, M]) if mqa_swap_seqlen_head: - # H/M dimensions have been swapped - out = out.transpose(1, 3) - lse = lse.transpose(2, 3) + out = out.reshape(B, -1, Mq, G, Kq).permute(0, 2, 3, 1, 4) + # This is a copy iff Mq, G and Hq are all > 1. + out = out.contiguous() if inp.query.ndim == 4: # BMGHK -> BMHK assert G == 1 out = out[:, :, 0] - lse = lse[:, 0] + if lse is not None: + lse = lse[:, 0] if Mk == 0: out.zero_() - return out, Context(out=out, lse=lse) - - -class FwOp_S1(FwOp): - SPLIT_K = 1 - NAME = "triton_splitK1" - - -class FwOp_S2(FwOp): - SPLIT_K = 2 - NAME = "triton_splitK2" - - -class FwOp_S4(FwOp): - SPLIT_K = 4 - NAME = "triton_splitK4" - - -class FwOp_S8(FwOp): - SPLIT_K = 8 - NAME = "triton_splitK8" - - -class FwOp_S16(FwOp): - SPLIT_K = 16 - NAME = "triton_splitK16" - - -class FwOp_S32(FwOp): - SPLIT_K = 32 - NAME = "triton_splitK32" + if lse is None: + return out, None + return out, Context(out=out, lse=lse) -class FwOp_S64(FwOp): - SPLIT_K = 64 - NAME = "triton_splitK64" + @classmethod + @functools.lru_cache + def get_operator( + cls, + splitk: int, + *, + block_m: Optional[int] = None, + block_n: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, + ) -> Type[AttentionFwOpBase]: + kwargs = { + "NAME": f"triton_splitK{splitk}", + "SPLIT_K": splitk, + } + if block_m is not None: + kwargs["BLOCK_M"] = block_m + if block_n is not None: + kwargs["BLOCK_N"] = block_n + if num_warps is not None: + kwargs["NUM_WARPS"] = num_warps + if num_stages is not None: + kwargs["NUM_STAGES"] = num_stages + return type( + f"FwOp_S{splitk}", + (cls,), + kwargs, + ) -class FwOp_S128(FwOp): - SPLIT_K = 128 - NAME = "triton_splitK128" +def merge_attentions( + attn_out: torch.Tensor, + lse_out: Optional[torch.Tensor], + attn_split: torch.Tensor, + lse_split: torch.Tensor, +): + B, M, G, H, Kq = attn_out.shape + if lse_out is not None: + B_H_G, M1 = lse_out.shape + B1, H_G, split_k, M_ceil, Kq1 = attn_split.shape + B2, H_G1, split_k1, M2 = lse_split.shape + + assert ( + B == B1 == B2 and G * H == H_G == H_G1 and M <= M_ceil and M == M2 and Kq == Kq1 + ), f"Incompatible shapes: {attn_out.shape=}, {attn_split.shape=}, {lse_split.shape=}" + if lse_out is not None: + assert ( + B * G * H == B_H_G and M == M1 + ), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}" + + grid = (B * G * H, M, 1) + _splitK_reduce[grid]( + attn_split, + lse_split, + attn_out, + lse_out, + split_k=split_k, + **_strides(attn_split.flatten(end_dim=1), "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(lse_split.flatten(end_dim=1), "lsek_zhg", "lsek_s", "lsek_m"), + **_strides(attn_out, "oz", "om", "og", "oh", "ok"), + **_strides(lse_out, "lse_zhg", "lse_m"), + BLOCK_SIZE=attn_out.shape[-1], + G=G, + H=H, + WRITE_LSE=lse_out is not None, + num_warps=2 if B * G * H >= 32 else 4, + ) + + +FwOp_Map = { + k: FwOp.get_operator(k) for k in [1, 2, 4, 8, 16, 32, 48, 64, 72, 80, 96, 112, 128] +} +FwOp_S1 = FwOp_Map[1] +FwOp_S2 = FwOp_Map[2] +FwOp_S4 = FwOp_Map[4] +FwOp_S8 = FwOp_Map[8] +FwOp_S16 = FwOp_Map[16] +FwOp_S32 = FwOp_Map[32] +FwOp_S64 = FwOp_Map[64] +FwOp_S128 = FwOp_Map[128]