merge_attention api
ghstack-source-id: b20fc46f920a7f80dc1e87a7ab4d88a591aae03b
Pull Request resolved: fairinternal/xformers#1039

__original_commit__ = fairinternal/xformers@bcae5bf
bottler authored and xFormers Bot committed Feb 29, 2024
1 parent 7f8c290 commit 41ebc03
Showing 4 changed files with 134 additions and 64 deletions.
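For orientation, here is a minimal usage sketch of the `fmha.merge_attentions` API this commit introduces, condensed from `test_merge_attentions_decoding` in the first changed file. It is illustrative only: the one-query-per-sequence decoding setup, chunk sizes, choice of the Triton split-K forward op, and loose bf16 tolerance are assumptions of the sketch, not part of the diff, and it needs an SM80+ CUDA GPU.

```python
import torch
from xformers.ops import fmha

B, H, D, T = 2, 8, 128, 256           # sequences, query heads, head dim, K/V length
dtype, device = torch.bfloat16, "cuda"

q = torch.randn(1, B, H, D, dtype=dtype, device=device)   # one query per sequence
k = torch.randn(B, T, 1, D, dtype=dtype, device=device)   # single K/V head (MQA)
v = torch.randn_like(k)

def partial_attention(k_chunk, v_chunk):
    # Attention of q against one chunk of K/V; returns (out, LSE) so it can be merged later.
    padding = k_chunk.shape[1]
    axk = k_chunk.reshape(1, -1, 1, D).expand(1, -1, H, D)
    axv = v_chunk.reshape(1, -1, 1, D).expand(1, -1, H, D)
    bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * B, kv_padding=padding, kv_seqlen=[padding] * B
    )
    return fmha.memory_efficient_attention_forward_requires_grad(
        q, axk, axv, bias, op=fmha.triton_splitk.FwOp
    )

# Attention over two halves of the K/V, then merged via the new API.
out1, lse1 = partial_attention(k[:, : T // 2], v[:, : T // 2])
out2, lse2 = partial_attention(k[:, T // 2 :], v[:, T // 2 :])
attn_out, lse_out = fmha.merge_attentions(
    torch.stack([out1, out2]), torch.stack([lse1, lse2])
)

# Reference: attention over the full K/V in one shot.
full_out, full_lse = partial_attention(k, v)
torch.testing.assert_close(attn_out, full_out, atol=1e-1, rtol=1e-1)  # loose bf16 tolerance
```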
102 changes: 66 additions & 36 deletions tests/test_mem_eff_attention.py
@@ -2608,18 +2608,32 @@ def paged_attention_run_inner(


@sm80_or_better_only
def test_merging_attentions_decoding():
@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("num_queries", [1, 2])
@pytest.mark.parametrize("bmghk", [True, False], ids=lambda x: "bmghk" if x else "")
def test_merge_attentions_decoding(
dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool
):
"""
Compute decoding attention on chunks of K/V and merge them together.
Compare with computing attention on the whole K/V.
"""

MAX_T = 8192
B = 128
N_KVH_L = 1
N_H_L = 8
D_H = 128
dtype = torch.bfloat16
G = 2 if bmghk else 1
torch.manual_seed(1)

num_chunks = 10

@@ -2629,26 +2643,34 @@ def test_merging_attentions_decoding():
chunk_starts[0] = 0
chunk_starts.append(MAX_T)

# We construct sequances so that even the last chunk has a non-empty part of every sequence.
# We construct sequences so that even the last chunk has a non-empty part of every sequence
# as long as the number of queries.
# Otherwise the corresponding LSE will be -inf and that'll propagate to the whole sum.
# It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases
# of merging attention, e.g. a batch of sequences with a common prefix, this condition should be satisfied.
k_lens = torch.randint(low=chunk_starts[-2] + 1, high=MAX_T, size=(B,)).tolist()
q_lens = [1 for _ in k_lens]
B_T = sum(q_lens)

q = torch.randn((1, B_T, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((B, MAX_T, N_KVH_L, D_H), dtype=dtype, device="cuda")
k_lens = torch.randint(
low=chunk_starts[-2] + num_queries, high=MAX_T, size=(B,)
).tolist()
q_lens = [num_queries] * B
B_T = num_queries * B

q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((B, MAX_T, G, 1, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
if not bmghk:
q = q[:, :, 0]

# Compute per-chunk attention
chunks_output = []
for i in range(num_chunks):
chunk_start, chunk_end = chunk_starts[i], chunk_starts[i + 1]
k_chunk = k[:, chunk_start:chunk_end, ...]
v_chunk = v[:, chunk_start:chunk_end, ...]
axk = k_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axv = v_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axk = k_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]

attn_bias = (
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
@@ -2663,40 +2685,49 @@ def test_merging_attentions_decoding():
axk,
axv,
attn_bias,
op=op,
)
attn_chunk = attn_chunk.reshape(B, -1, N_H_L, D_H)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
assert lse_chunk.shape == (1, G, N_H_L, B_T)
else:
assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
assert lse_chunk.shape == (1, N_H_L, B_T)
chunks_output.append((attn_chunk, lse_chunk))

# Merge attention from all chunks
attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output])
lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output])
attn_out, lse_out = fmha.merge_attentions(
attn_split.permute(0, 1, 3, 2, 4), lse_split
)
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)

# Compute attention on the full K/V
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q_lens,
kv_padding=MAX_T,
kv_seqlen=k_lens,
)
axk = k.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axv = v.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axk = k.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]
attn_full, lse_full = fmha.memory_efficient_attention_forward_requires_grad(
q,
axk,
axv,
attn_bias,
op=op,
)

attn_out = attn_out.reshape(1, B_T, N_H_L, D_H)
torch.testing.assert_close(lse_out, lse_full, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(attn_out, attn_full, rtol=1e-3, atol=1e-3)
atol = op.ERROR_ATOL[dtype] * (10 if op is fmha.triton_splitk.FwOp else 1)
rtol = op.ERROR_RTOL[dtype] * 2
assert_allclose(lse_out, lse_full, rtol=rtol * 2, atol=atol, msg="lse")
assert_allclose(attn_out, attn_full, rtol=rtol, atol=atol, msg="out")


@sm80_or_better_only
@pytest.mark.parametrize("bmghk", (False, True))
def test_merging_attentions_against_ref(bmghk: bool):
def test_merge_attentions_against_ref(bmghk: bool):
split_k = 16
B = 12
M = 137
@@ -2705,12 +2736,12 @@ def test_merging_attentions_against_ref(bmghk: bool):
D_H = 128
dtype = torch.float32

attn_split = torch.randn([split_k, B, N_H_L, G, M, D_H], dtype=dtype, device="cuda")
lse_split = torch.randn([split_k, B, N_H_L, G, M], dtype=dtype, device="cuda")
attn_split = torch.randn([split_k, B, M, G, N_H_L, D_H], dtype=dtype, device="cuda")
lse_split = torch.randn([split_k, B, G, N_H_L, M], dtype=dtype, device="cuda")

if not bmghk:
attn_split = attn_split[:, :, :, 0, :, :]
lse_split = lse_split[:, :, :, 0, :]
attn_split = attn_split[:, :, :, 0]
lse_split = lse_split[:, :, 0]

attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)

@@ -2722,29 +2753,28 @@ def test_merging_attentions_against_ref(bmghk: bool):

def _merge_attentions_ref(attn_split, lse_split):
"""
attn_split: [split_k, B, H, G, M_ceil, Kq]
lse_split: [split_k, B, H, G, M]
attn_split: [split_k, B, M, (G,) H, Kq]
lse_split: [split_k, B, (G,) H, M]
"""
is_bmghk = len(attn_split.shape) == 6
if not is_bmghk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(2)

lse_split = lse_split.unsqueeze(5) # [split_k, B, M, G, H, 1]
lse_split = lse_split[..., None].moveaxis(4, 2) # [split_k, B, M, G, H, 1]

lse_max, _ = torch.max(lse_split, dim=0, keepdim=True) # [1, B, M, G, H, 1]
lse_max, _ = torch.max(lse_split, dim=0) # [B, M, G, H, 1]
sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1]
denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1]
numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, K]

attn_out = numerator / denominator # [B, M_ceil, G, H, Kq]
lse_out = (lse_max.squeeze(0) + torch.log(denominator)).squeeze(
4
) # [B, M_ceil, G, H]
lse_out = lse_max + torch.log(denominator)
lse_out = lse_out.squeeze(4).permute(0, 2, 3, 1) # [B, G, H, M]

if not is_bmghk:
attn_out = attn_out.squeeze(2)
lse_out = lse_out.squeeze(2)
lse_out = lse_out.squeeze(1)

return attn_out, lse_out

40 changes: 21 additions & 19 deletions xformers/ops/fmha/__init__.py
@@ -428,12 +428,15 @@ def merge_attentions(
Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
LSE_full = log(exp(LSE1) + exp(LSE2) + ...)
Attention inputs are in BH(G)MK format, stacked along dim 0. Attention output also is in BH(G)MK.
Args:
attn_split: [split_k, B, H, G, M, Kq] or [split_k, B, H, M, Kq]
lse_split: [split_k, B, H, G, M] or [split_k, B, H, M]
Res:
attn_out: [B, H, G, M, K] or [B, H, M, K]
lse_out: [B, H, G, M] or [B, H, M]
attn_split: [split_k, B, M, G, H, Kq] or [split_k, B, M, H, Kq]
lse_split: [split_k, B, G, H, M] or [split_k, B, H, M]
Returns:
attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
lse_out: [B, G, H, M] or [B, H, M] if write_lse
or None otherwise
"""

assert (
@@ -443,21 +446,20 @@ def merge_attentions(
is_bmhk = attn_split.ndim == 5
if is_bmhk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(3)

split_k, B, H, G, M_ceil, Kq = attn_split.shape
split_k1, B1, H1, G1, M = lse_split.shape
assert (
B == B1 and G == G1 and H == H1 and split_k == split_k1 and M_ceil >= M
), f"{attn_split.shape=} {lse_split.shape=}"
lse_split = lse_split.unsqueeze(2)

attn_split = attn_split.permute(1, 2, 3, 0, 4, 5).view(
B, H * G, split_k, M_ceil, Kq
split_k, B, M, G, H, Kq = attn_split.shape
split_k1, B1, G1, H1, M1 = lse_split.shape
assert B == B1 and G == G1 and H == H1 and split_k == split_k1 and M == M, (
f"{attn_split.shape=} {lse_split.shape=} "
f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {split_k}/{split_k1}, {M}/{M}"
)
lse_split = lse_split.permute(1, 2, 3, 0, 4).view(B, H * G, split_k, M)

attn_split = attn_split.permute(1, 3, 4, 0, 2, 5).reshape(B, G * H, split_k, M, Kq)
lse_split = lse_split.permute(1, 2, 3, 0, 4).reshape(B, G * H, split_k, M)

attn_out = torch.empty(
B, H, G, M, Kq, device=attn_split.device, dtype=attn_split.dtype
B, M, G, H, Kq, device=attn_split.device, dtype=attn_split.dtype
)
if write_lse:
lse_out = torch.empty(
@@ -467,15 +469,15 @@ def merge_attentions(
lse_out = None

triton_splitk.merge_attentions(
attn_out.permute(0, 3, 2, 1, 4), lse_out, attn_split, lse_split
attn_out.permute(0, 1, 3, 2, 4), lse_out, attn_split, lse_split
)
if lse_out is not None:
lse_out = lse_out.view(B, H, G, M)
lse_out = lse_out.view(B, G, H, M)

if is_bmhk:
attn_out = attn_out[:, :, 0]
if lse_out is not None:
lse_out = lse_out[:, :, 0]
lse_out = lse_out[:, 0]

return attn_out, lse_out

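The merge rule documented in the `merge_attentions` docstring above — `Out_full = (Out1*exp(LSE1) + Out2*exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)`, `LSE_full = log(exp(LSE1) + exp(LSE2) + ...)` — can be checked with plain softmax attention, independently of any xformers kernel. The sketch below is not part of the commit; sizes are arbitrary and it runs on CPU in float64. (The real kernel and `_merge_attentions_ref` subtract max(LSE) before exponentiating for numerical stability; the naive form is fine at this toy scale.)

```python
import torch

torch.manual_seed(0)
D, Tq, Tk = 16, 4, 32
q = torch.randn(Tq, D, dtype=torch.float64)
k = torch.randn(Tk, D, dtype=torch.float64)
v = torch.randn(Tk, D, dtype=torch.float64)

def attn_with_lse(q, k, v):
    # Softmax attention plus the per-query log-sum-exp of the scaled scores.
    scores = q @ k.T / D**0.5                 # [Tq, Tk]
    lse = torch.logsumexp(scores, dim=-1)     # [Tq]
    out = torch.softmax(scores, dim=-1) @ v   # [Tq, D]
    return out, lse

# Attention over two disjoint chunks of K/V ...
out1, lse1 = attn_with_lse(q, k[:20], v[:20])
out2, lse2 = attn_with_lse(q, k[20:], v[20:])

# ... merged exactly as the docstring states.
w1, w2 = lse1.exp()[:, None], lse2.exp()[:, None]
out_merged = (out1 * w1 + out2 * w2) / (w1 + w2)
lse_merged = torch.logaddexp(lse1, lse2)

# Matches attention computed over the whole K/V at once.
out_full, lse_full = attn_with_lse(q, k, v)
torch.testing.assert_close(out_merged, out_full)
torch.testing.assert_close(lse_merged, lse_full)
```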
27 changes: 26 additions & 1 deletion xformers/ops/fmha/flash.py
@@ -402,6 +402,25 @@ def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) ->
)


def _post_process_lse(
lse: torch.Tensor, inp: Inputs, original_query_shape
) -> torch.Tensor:
if not isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
return lse
q_seqinfo = inp.attn_bias.q_seqinfo
B = len(q_seqinfo.seqstart_py) - 1
if q_seqinfo.max_seqlen * B != original_query_shape[1]:
# Heterogeneous batch. We can't fix it.
return lse

# reshape from (B, G*H, max_seqlen) to (1, G*H, B*max_seqlen)
# Unfortunately this flatten is not just a view.
lse_hkm = lse.permute(1, 0, 2).flatten(start_dim=1)[None]
if len(original_query_shape) == 5:
return lse_hkm.unflatten(1, original_query_shape[2:4])
return lse_hkm


@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
@@ -449,6 +468,8 @@ def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
return_softmax = False
original_query_shape = inp.query.shape

out_shape = [
*inp.query.shape[:-1],
inp.value.shape[-1],
@@ -489,7 +510,11 @@ def apply(
device=inp.query.device,
dtype=torch.float32,
)
ctx = Context(out=out, lse=softmax_lse)
if not needs_gradient:
return out, None
ctx = Context(
out=out, lse=_post_process_lse(softmax_lse, inp, original_query_shape)
)
if inp.p != 0.0:
ctx.op_bw = BwOp
ctx.rng_state = rng_state
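As a toy illustration of the layout change `_post_process_lse` above performs (not part of the diff; shapes are made up): flash returns the LSE per padded batch entry as (B, G*H, max_seqlen), while the packed convention expected by `merge_attentions` is (1, G*H, B*max_seqlen), optionally unflattened to (1, G, H, B*max_seqlen) for BMGHK queries.

```python
import torch

B, G, H, max_seqlen = 3, 2, 4, 5
lse = torch.arange(B * G * H * max_seqlen, dtype=torch.float32).reshape(B, G * H, max_seqlen)

# Same permute + flatten as _post_process_lse; the flatten is a copy because the
# permuted tensor is no longer contiguous along the flattened dimensions.
lse_packed = lse.permute(1, 0, 2).flatten(start_dim=1)[None]   # (1, G*H, B*max_seqlen)
lse_bmghk = lse_packed.unflatten(1, (G, H))                    # (1, G, H, B*max_seqlen)

# Entry (b, gh, m) of the flash layout lands at packed query position b*max_seqlen + m.
b, gh, m = 1, 5, 2
assert lse_packed[0, gh, b * max_seqlen + m] == lse[b, gh, m]
print(lse_bmghk.shape)  # torch.Size([1, 2, 4, 15])
```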
29 changes: 21 additions & 8 deletions xformers/ops/fmha/triton_splitk.py
@@ -837,6 +837,7 @@ def apply(
attn_bias.k_seqinfo.to(inp.query.device) # type: ignore
attn_bias.q_seqinfo.to(inp.query.device) # type: ignore
seq_len = attn_bias.k_seqinfo.seqlen # type: ignore
assert q.shape[0] == 1
B = len(seq_len)
G, Hq, Kq = q.shape[-3:]
Kkv = v.shape[-1]
@@ -984,6 +985,8 @@ def grid(META):
if needs_gradient:
assert lse_splitk is not None
lse = lse_splitk[:, :, 0].view(B, G, -1, Mq)
if attn_bias is not None:
lse = lse.permute(1, 2, 0, 3).reshape(1, G, H, B * Mq)
else:
lse = None

@@ -1007,9 +1010,12 @@

# Merge attention and LSE outputs from different split-k chunks
assert lse_splitk is not None
merge_attentions(out, lse, o_splitk, lse_splitk)
merge_attentions(out, lse, o_splitk[:, :, :, :M], lse_splitk)
if lse is not None:
lse = lse.reshape([B, G, H, M])
if attn_bias is not None:
lse = lse.permute(1, 2, 0, 3).reshape(1, G, H, B * M)

if mqa_swap_seqlen_head:
out = out.reshape(B, -1, Mq, G, Kq).permute(0, 2, 3, 1, 4)
# This is a copy iff Mq, G and Hq are all > 1.
@@ -1067,26 +1073,33 @@ def merge_attentions(
B, M, G, H, Kq = attn_out.shape
if lse_out is not None:
B_H_G, M1 = lse_out.shape
B1, H_G, split_k, M_ceil, Kq1 = attn_split.shape
B2, H_G1, split_k1, M2 = lse_split.shape
B1, H_G, split_k, M2, Kq1 = attn_split.shape
B2, H_G1, split_k1, M3 = lse_split.shape

assert (
B == B1 == B2 and G * H == H_G == H_G1 and M <= M_ceil and M == M2 and Kq == Kq1
B == B1 == B2 and G * H == H_G == H_G1 and M == M2 == M3 and Kq == Kq1
), f"Incompatible shapes: {attn_out.shape=}, {attn_split.shape=}, {lse_split.shape=}"
assert (
split_k == split_k1
), f"Incompatible shapes: {attn_split.shape=}, {lse_split.shape=}"
if lse_out is not None:
assert (
B * G * H == B_H_G and M == M1
), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"

# TODO: avoid this copy in more cases
attn_split_ = attn_split.flatten(end_dim=1)
lse_split_ = lse_split.flatten(end_dim=1)

grid = (B * G * H, M, 1)
_splitK_reduce[grid](
attn_split,
lse_split,
attn_split_,
lse_split_,
attn_out,
lse_out,
split_k=split_k,
**_strides(attn_split.flatten(end_dim=1), "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(lse_split.flatten(end_dim=1), "lsek_zhg", "lsek_s", "lsek_m"),
**_strides(attn_split_, "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(lse_split_, "lsek_zhg", "lsek_s", "lsek_m"),
**_strides(attn_out, "oz", "om", "og", "oh", "ok"),
**_strides(lse_out, "lse_zhg", "lse_m"),
BLOCK_SIZE=attn_out.shape[-1],
