fbcode exports: gappy biases, fewer reshapings (fairinternal/xformers#1043)

Non-causal and gappy versions of the padded biases.
triton_splitk and merge_attentions layout fixes to avoid copies.

__original_commit__ = fairinternal/xformers@50b82df
bottler authored and xFormers Bot committed Mar 12, 2024
1 parent 0ccd367 commit 2c719c6
Showing 6 changed files with 668 additions and 174 deletions.
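The new non-causal and paged padded-keys biases introduced here are exercised throughout the tests below. As a rough usage sketch, assuming the from_seqlens and make_paged signatures that the updated tests call (this snippet is illustrative and not part of the commit):

import torch
from xformers.ops import fmha

# Non-causal padded-keys bias: three queries of length 1, each attending to a
# KV block padded to 32 entries, of which only kv_seqlen are valid.
bias = fmha.attn_bias.BlockDiagonalPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1, 1],
    kv_padding=32,
    kv_seqlen=[5, 17, 32],
)

# Paged variant: the same logical mask, but keys live in pages of a physical
# cache addressed through a block table (one page of 32 keys per sequence here).
block_tables = torch.arange(3, dtype=torch.int32, device="cuda").reshape(3, 1)
paged_bias = bias.make_paged(
    block_tables=block_tables,
    page_size=32,
    paged_type=fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
)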
149 changes: 125 additions & 24 deletions tests/test_mem_eff_attention.py
@@ -187,8 +187,11 @@ def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(
}:
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2
elif bias_type in {
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
}:
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq)
shape = (B, Mq, Mkv, H, K, Kv)
@@ -375,28 +378,43 @@ def create_tensors(
if mask_is_bottom_right and q_len > kv_len:
# Bottom-right attention and local-attention masks require q_len <= kv_len
kv_len = q_len

if attn_bias_type is not None and issubclass(
attn_bias_type,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
):
page_size_choices = [256, 512]
if issubclass(op, fmha.triton_splitk.FwOp):
# TODO: enable small pages for flash attention when that's implemented
page_size_choices.extend([64, 128])
page_size = random.choice(page_size_choices)
kv_len_paged = (kv_len + page_size - 1) // page_size * page_size
else:
kv_len_paged = kv_len
page_size = None

scale = 3
if fmt == "BMK":
query = torch.randn((B * h, q_len, k), device=device, dtype=dtype)
key = torch.randn((B * h, kv_len, k), device=device, dtype=dtype)
value = torch.randn((B * h, kv_len, kv), device=device, dtype=dtype)
key = torch.randn((B * h, kv_len_paged, k), device=device, dtype=dtype)
value = torch.randn((B * h, kv_len_paged, kv), device=device, dtype=dtype)
elif fmt == "BMHK":
query = torch.randn((B, q_len, h, k), device=device, dtype=dtype)
key = torch.randn((B, kv_len, h, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len, h, kv), device=device, dtype=dtype)
key = torch.randn((B, kv_len_paged, h, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len_paged, h, kv), device=device, dtype=dtype)
else:
assert fmt == "BMGHK"
query = torch.randn((B, q_len, g, h, k), device=device, dtype=dtype)
key = torch.randn((B, kv_len, g, 1, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len, g, 1, kv), device=device, dtype=dtype)
key = torch.randn((B, kv_len_paged, g, 1, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len_paged, g, 1, kv), device=device, dtype=dtype)

for x in [query, key, value]:
x.mul_(scale)

if fmt == "BMGHK":
# Expand - after the in-place mul
key = key.expand((B, kv_len, g, h, k))
value = value.expand((B, kv_len, g, h, k))
key = key.expand((B, kv_len_paged, g, h, k))
value = value.expand((B, kv_len_paged, g, h, k))

if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(attn_bias_type):
attn_bias_type = None
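
For the paged bias types, create_tensors above rounds the KV length up to a whole number of pages before allocating the key/value tensors. A quick worked check of that rounding (illustrative only):

# Same ceil-division used in create_tensors above.
kv_len, page_size = 300, 256
kv_len_paged = (kv_len + page_size - 1) // page_size * page_size
assert kv_len_paged == 512  # 300 keys need two pages of 256 slots each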
@@ -414,13 +432,15 @@ def create_tensors(
requires_grad=attn_bias_requires_grad,
fmt=fmt,
op=op,
page_size=page_size,
)
if isinstance(
attn_bias,
(
fmha.attn_bias.BlockDiagonalMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.BlockDiagonalGappyKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
),
):
query, key, value = [
@@ -467,7 +487,12 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
k,
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv

if packed and issubclass(
bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask
):
pytest.skip(
"packed doesn't make sense with paged attention, since q has different shape than k/v"
)
if packed and not (k == kv and q_len == kv_len):
pytest.skip(
f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`"
@@ -1501,6 +1526,7 @@ def test_attn_bias_blockdiag_crossattn_causal_with_prefix() -> None:
@cuda_only
def test_attn_bias_padded() -> None:
bsize, n_heads, d, padding = 8, 3, 8, 32
torch.manual_seed(0)

# Q / KV have different seqlen
k = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16)
@@ -2482,7 +2508,10 @@ def paged_attention_run_inner(
D_H_KV = D_H // 8 + num_quant_groups if num_quant_groups else D_H
kv_seqlens = torch.randint(low=1, high=MAX_T + 1, size=(B,)).tolist()

attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
paged_type = fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
block_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask

attn_bias = block_type.from_seqlens(
q_seqlen=[1] * B,
kv_padding=MAX_T,
kv_seqlen=kv_seqlens,
@@ -2548,7 +2577,7 @@ def paged_attention_run_inner(
axv_padded = axv_padded.expand(-1, -1, N_H_L, -1)

attn_bias_paged = attn_bias.make_paged(
block_tables=block_tables, page_size=page_size
block_tables=block_tables, page_size=page_size, paged_type=paged_type
)

y_usual = fmha.memory_efficient_attention_forward(
@@ -2614,7 +2643,7 @@ def paged_attention_run_inner(
page_size,
)
attn_bias_paged = attn_bias.make_paged(
block_tables=block_tables, page_size=page_size
block_tables=block_tables, page_size=page_size, paged_type=paged_type
)
axk = packed_cache_k.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)
axv = packed_cache_v.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)
@@ -2755,6 +2784,69 @@ def test_merge_attentions_nobias(
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("num_queries", [1])
@pytest.mark.parametrize("bmghk", [True, False], ids=lambda x: "bmghk" if x else "")
def test_partial_paged(
dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool
):
B = 128
N_H_L = 8
D_H = 128
page_size = 256
G = 2 if bmghk else 1
block_tables = torch.zeros((B, 1), dtype=torch.int32, device="cuda")
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None

B_T = num_queries * B

q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((1, page_size, G, 1, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
k = k.expand(1, page_size, G, N_H_L, D_H)
v = v.expand(1, page_size, G, N_H_L, D_H)
if not bmghk:
q = q[:, :, 0]
k = k[:, :, 0]
v = v[:, :, 0]

attn_bias = (
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[num_queries] * B,
kv_seqlen=[1] + ([100] * (B - 1)),
page_size=page_size,
block_tables=block_tables,
)
)

attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
q,
k,
v,
attn_bias,
op=op,
output_dtype=output_dtype,
)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
assert lse_chunk.shape == (1, G, N_H_L, B_T)
else:
assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
assert lse_chunk.shape == (1, N_H_L, B_T)


@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
(torch.bfloat16, fmha.triton_splitk.FwOp_S32),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
@@ -2813,12 +2905,13 @@ def test_merge_attentions_decoding(
axk = axk[:, :, 0]
axv = axv[:, :, 0]

attn_bias = (
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q_lens,
kv_padding=chunk_end - chunk_start,
kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens],
)
bias_type = fmha.attn_bias.BlockDiagonalPaddedKeysMask
if i + 1 == num_chunks:
bias_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask
attn_bias = bias_type.from_seqlens(
q_seqlen=q_lens,
kv_padding=chunk_end - chunk_start,
kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens],
)

attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
@@ -2863,14 +2956,22 @@ def test_merge_attentions_decoding(
output_dtype=output_dtype,
)

atol = op.ERROR_ATOL[dtype]
rtol = op.ERROR_RTOL[dtype] * 2
assert_allclose(
lse_out.to(lse_full.dtype), lse_full, rtol=rtol * 2, atol=atol, msg="lse"
lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse"
)
assert_allclose(
attn_out.to(attn_full.dtype), attn_full, rtol=rtol, atol=atol, msg="out"
attn_out.to(attn_full.dtype), attn_full, rtol=1e-3, atol=1e-3, msg="out"
)

attn_full2 = fmha.memory_efficient_attention_forward(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
assert_allclose(attn_full2, attn_full, rtol=1e-3, atol=1e-3, msg="out2")


@sm80_or_better_only
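test_merge_attentions_decoding above splits the KV cache into chunks, runs memory_efficient_attention_partial on each chunk, and merges the partial results. The merge semantics it relies on can be written as a log-sum-exp weighted sum over chunks; a reference sketch of that combination (illustrative only, not the Triton kernel this commit touches):

import torch

def merge_reference(attn_chunks, lse_chunks):
    # attn_chunks: per-chunk outputs of shape [B, M, H, K]
    # lse_chunks: per-chunk log-sum-exps of shape [B, H, M]
    lse = torch.stack(lse_chunks)                        # [C, B, H, M]
    attn = torch.stack(attn_chunks)                      # [C, B, M, H, K]
    weights = torch.softmax(lse, dim=0)                  # weight of each chunk per query/head
    out = (attn * weights.permute(0, 1, 3, 2)[..., None]).sum(dim=0)
    lse_out = torch.logsumexp(lse, dim=0)                # merged log-sum-exp
    return out, lse_out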
42 changes: 32 additions & 10 deletions xformers/attn_bias_utils.py
@@ -39,6 +39,7 @@ def create_attn_bias(
requires_grad: bool,
fmt: str,
op: Type[AttentionOpBase],
page_size: Optional[int] = None,
):
if bias_type is None or isinstance(None, bias_type):
return None
@@ -151,28 +152,49 @@ def create_attn_bias(
block_diag = block_diag.make_causal_from_bottomright()
return block_diag
if bias_type in [
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
]:
assert fmt in ["BMHK", "BMGHK"]
q, k = _rand_seqlens_padded_k(r, batch_size, q_len, kv_len)
g_block_diag = (
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q,
kv_padding=kv_len,
kv_seqlen=k,
)
block_diag_type = (
bias_type._UNPAGED_TYPE
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask)
else bias_type
)
g_block_diag = block_diag_type.from_seqlens(
q_seqlen=q,
kv_padding=kv_len,
kv_seqlen=k,
)
if bias_type == fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask:
page_size = r.choice([64, 128, 256])
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask):
assert page_size is not None
pages_per_row = (kv_len + page_size - 1) // page_size
block_tables = torch.randperm(
batch_size * pages_per_row, device=device
batch_size * pages_per_row, device=device, dtype=torch.int32
).reshape(batch_size, pages_per_row)
return g_block_diag.make_paged(
block_tables=block_tables, page_size=page_size
block_tables=block_tables, page_size=page_size, paged_type=bias_type
)
return g_block_diag
if bias_type in [
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
fmha.attn_bias.BlockDiagonalGappyKeysMask,
]:
assert fmt in ["BMHK", "BMGHK"]
max_q_minus_k = (
None if bias_type is fmha.attn_bias.BlockDiagonalGappyKeysMask else 0
)
q, k = _rand_seqlens(r, batch_size, q_len, kv_len, max_q_minus_k)
total_kv_len = kv_len * batch_size
starts = [r.randint(0, total_kv_len - ki) for ki in k] + [total_kv_len]
return fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
q_seqlen=q,
kv_seqstarts=starts,
kv_seqlen=k,
)
if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
return bias_type(
window_left=r.randint(0, 5),
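The gappy-keys branch above places each batch element's keys at an arbitrary start offset inside one flat KV buffer, with the final entry of kv_seqstarts giving the total buffer length. A concrete illustration of that from_seqlens call, with made-up lengths:

from xformers.ops import fmha

# Two query blocks of lengths 3 and 2; their keys occupy [0, 40) and
# [100, 125) of a shared KV buffer of total length 160.
gappy = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
    q_seqlen=[3, 2],
    kv_seqstarts=[0, 100, 160],
    kv_seqlen=[40, 25],
)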
20 changes: 10 additions & 10 deletions xformers/ops/fmha/__init__.py
@@ -20,10 +20,12 @@
)
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
BlockDiagonalMask,
BlockDiagonalPaddedKeysMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
PagedBlockDiagonalPaddedKeysMask,
)
from .common import (
AttentionBwOpBase,
@@ -479,7 +481,9 @@ def memory_efficient_attention_partial(
attn_bias,
(
type(None),
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
BlockDiagonalPaddedKeysMask,
PagedBlockDiagonalPaddedKeysMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
),
@@ -544,8 +548,8 @@ def merge_attentions(
f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {split_k}/{split_k1}, {M}/{M}"
)

attn_split = attn_split.permute(1, 3, 4, 0, 2, 5).reshape(B, G * H, split_k, M, Kq)
lse_split = lse_split.permute(1, 2, 3, 0, 4).reshape(B, G * H, split_k, M)
attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
lse_split = lse_split.permute(1, 2, 3, 0, 4)

attn_out = torch.empty(
B,
@@ -558,16 +562,12 @@
)
if write_lse:
lse_out = torch.empty(
B * H * G, M, device=attn_split.device, dtype=lse_split.dtype
B, G, H, M, device=attn_split.device, dtype=lse_split.dtype
)
else:
lse_out = None

triton_splitk.merge_attentions(
attn_out.permute(0, 1, 3, 2, 4), lse_out, attn_split, lse_split
)
if lse_out is not None:
lse_out = lse_out.view(B, G, H, M)
triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split)

if is_bmhk:
attn_out = attn_out[:, :, 0]
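The merge_attentions change above hands permuted 6-D views straight to the Triton kernel and allocates lse_out directly in (B, G, H, M) layout, rather than reshaping inputs and outputs around the call. Per the commit message, the goal is to avoid copies: a permute is always a stride-only view, while reshaping a permuted tensor can force a full materialization when the target shape cannot be expressed with the existing strides. A generic illustration of that distinction (not the exact tensors used above):

import torch

x = torch.randn(4, 3, 2)

# permute only rewrites strides; no data is moved.
p = x.permute(2, 1, 0)
assert p.data_ptr() == x.data_ptr()

# Flattening the reversed dims cannot be expressed with strides alone,
# so reshape materializes a copy here (and p.view(2, 12) would raise).
r = p.reshape(2, 12)
assert r.data_ptr() != x.data_ptr()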