Commit

…nto dev_upstream

tenpercent committed Mar 1, 2024
2 parents c5ea221 + fe0526b commit b585563

Showing 6 changed files with 340 additions and 81 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -247,7 +247,7 @@ If you use xFormers in your publication, please cite it by using the following B

``` bibtex
@Misc{xFormers2022,
author = {Benjamin Lefaudeux and Francisco Massa and Diana Liskovich and Wenhan Xiong and Vittorio Caggiano and Sean Naren and Min Xu and Jieru Hu and Marta Tintore and Susan Zhang and Patrick Labatut and Daniel Haziza},
author = {Benjamin Lefaudeux and Francisco Massa and Diana Liskovich and Wenhan Xiong and Vittorio Caggiano and Sean Naren and Min Xu and Jieru Hu and Marta Tintore and Susan Zhang and Patrick Labatut and Daniel Haziza and Luca Wehrstedt and Jeremy Reizenstein and Grigory Sizov},
title = {xFormers: A modular and hackable Transformer modelling library},
howpublished = {\url{https://github.com/facebookresearch/xformers}},
year = {2022}
172 changes: 135 additions & 37 deletions tests/test_mem_eff_attention.py
@@ -2689,18 +2689,93 @@ def paged_attention_run_inner(


@sm80_or_better_only
def test_merging_attentions_decoding():
@pytest.mark.parametrize(
"op",
[
fmha.triton_splitk.FwOp,
fmha.flash.FwOp,
None,
],
ids=lambda op: "None" if op is None else op.NAME,
)
@pytest.mark.parametrize("G,H", [(1, 11), (7, 1), (1, 1), (7, 11), (None, 11)])
@pytest.mark.parametrize(
"write_lse", (False, True), ids=lambda x: "write_lse" if x else ""
)
def test_merge_attentions_nobias(
write_lse: bool, op: Type[AttentionFwOpBase], G: Optional[int], H: int
):
"""
Merging the same attention twice shouldn't change anything.
This also tests the shape of the lse output of each permitted op.
"""
B, M, Mq, K = 13, 5, 3, 128
if op is None or torch.bfloat16 in op.SUPPORTED_DTYPES:
dtype = torch.bfloat16
else:
dtype = next(iter(op.SUPPORTED_DTYPES))
if G is None:
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
v = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
else:
q = 3 * torch.rand(B, Mq, G, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
v = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
out1, lse1 = fmha.memory_efficient_attention_partial(q, k, v, op=op)
assert out1.shape == q.shape
M_ceil = lse1.shape[-1]
assert M_ceil >= Mq
assert lse1.shape == ((B, H, M_ceil) if G is None else (B, G, H, M_ceil))
lse1 = lse1[..., :Mq]

out, lse = fmha.merge_attentions(
torch.stack([out1, out1]), torch.stack([lse1, lse1]), write_lse=write_lse
)
assert out.shape == out1.shape
assert_allclose(out1, out, rtol=1e-3, atol=1e-3, msg="out")
if write_lse:
assert lse is not None
assert lse.shape[:-1] == lse1.shape[:-1]
assert_allclose(
lse1[..., :Mq] + math.log(2), lse[..., :Mq], rtol=1e-3, atol=1e-3, msg="lse"
)
else:
assert lse is None
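
A side note on the `math.log(2)` offset checked above: merging a partial attention result with an identical copy of itself leaves the output unchanged but shifts the log-sum-exp by log(2), because the softmax denominator is counted twice. A minimal standalone sketch in plain PyTorch (not part of this diff; the tensor names and toy shapes are made up for illustration):

``` python
import math
import torch

# One partial attention result: per-query log-sum-exp and output (toy shapes).
lse = torch.randn(5)           # [M]
out = torch.randn(5, 8)        # [M, K]

# Merge the result with an identical copy using the LSE-weighted formula.
lse_stack = torch.stack([lse, lse])                   # [2, M]
out_stack = torch.stack([out, out])                   # [2, M, K]
lse_max = lse_stack.max(dim=0, keepdim=True).values   # [1, M]
sumexp = torch.exp(lse_stack - lse_max)               # [2, M], all ones here
denom = sumexp.sum(dim=0)                             # [M], equals 2 everywhere
merged_out = (sumexp[..., None] * out_stack).sum(dim=0) / denom[..., None]
merged_lse = lse_max.squeeze(0) + denom.log()         # [M]

assert torch.allclose(merged_out, out, atol=1e-6)                # output unchanged
assert torch.allclose(merged_lse, lse + math.log(2), atol=1e-6)  # LSE shifted by log(2)
```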


@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("num_queries", [1, 2])
@pytest.mark.parametrize("bmghk", [True, False], ids=lambda x: "bmghk" if x else "")
def test_merge_attentions_decoding(
dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool
):
"""
Compute decoding attention on chunks of K/V and merge them together.
Compare with computing attention on the whole K/V.
"""

MAX_T = 8192
B = 128
N_KVH_L = 1
N_H_L = 8
D_H = 128
dtype = torch.bfloat16
G = 2 if bmghk else 1
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None

num_chunks = 10

@@ -2710,26 +2785,34 @@ def test_merging_attentions_decoding():
chunk_starts[0] = 0
chunk_starts.append(MAX_T)

# We construct sequences so that even the last chunk contains a part of every sequence
# at least as long as the number of queries.
# Otherwise the corresponding LSE will be -inf and that'll propagate to the whole sum.
# It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases
# of merging attention, e.g. a batch of sequences with a common prefix, this condition should be satisfied.
k_lens = torch.randint(low=chunk_starts[-2] + 1, high=MAX_T, size=(B,)).tolist()
q_lens = [1 for _ in k_lens]
B_T = sum(q_lens)

q = torch.randn((1, B_T, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((B, MAX_T, N_KVH_L, D_H), dtype=dtype, device="cuda")
k_lens = torch.randint(
low=chunk_starts[-2] + num_queries, high=MAX_T, size=(B,)
).tolist()
q_lens = [num_queries] * B
B_T = num_queries * B

q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((B, MAX_T, G, 1, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
if not bmghk:
q = q[:, :, 0]

# Compute per-chunk attention
chunks_output = []
for i in range(num_chunks):
chunk_start, chunk_end = chunk_starts[i], chunk_starts[i + 1]
k_chunk = k[:, chunk_start:chunk_end, ...]
v_chunk = v[:, chunk_start:chunk_end, ...]
axk = k_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axv = v_chunk.reshape(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axk = k_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]

attn_bias = (
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
@@ -2739,45 +2822,61 @@
)
)

attn_chunk, lse_chunk = fmha.memory_efficient_attention_forward_requires_grad(
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
attn_chunk = attn_chunk.reshape(B, -1, N_H_L, D_H)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
assert lse_chunk.shape == (1, G, N_H_L, B_T)
else:
assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
assert lse_chunk.shape == (1, N_H_L, B_T)
chunks_output.append((attn_chunk, lse_chunk))

# Merge attention from all chunks
attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output])
lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output])
attn_out, lse_out = fmha.merge_attentions(
attn_split.permute(0, 1, 3, 2, 4), lse_split
)
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split, output_dtype=dtype)
assert lse_out is not None

# Compute attention on the full K/V
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q_lens,
kv_padding=MAX_T,
kv_seqlen=k_lens,
)
axk = k.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axv = v.view(1, -1, N_KVH_L, D_H).expand(1, -1, N_H_L, D_H)
axk = k.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]
attn_full, lse_full = fmha.memory_efficient_attention_forward_requires_grad(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)

attn_out = attn_out.reshape(1, B_T, N_H_L, D_H)
torch.testing.assert_close(lse_out, lse_full, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(attn_out, attn_full, rtol=1e-3, atol=1e-3)
atol = op.ERROR_ATOL[dtype]
rtol = op.ERROR_RTOL[dtype] * 2
assert_allclose(
lse_out.to(lse_full.dtype), lse_full, rtol=rtol * 2, atol=atol, msg="lse"
)
assert_allclose(
attn_out.to(attn_full.dtype), attn_full, rtol=rtol, atol=atol, msg="out"
)
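
For intuition on why the chunked computation above matches attention over the full K/V (and why an empty chunk would be fatal): attention over concatenated K/V chunks is exactly the LSE-weighted combination of the per-chunk results, and a chunk with no keys has a sum of exponentials of 0, i.e. an LSE of -inf that turns the merge weights into NaNs. A standalone single-query sketch in plain PyTorch, independent of the xFormers ops used in the test (all names and sizes here are illustrative):

``` python
import torch

torch.manual_seed(0)
D = 16
q = torch.randn(D)                                 # a single decoding query
k1, v1 = torch.randn(40, D), torch.randn(40, D)    # first K/V chunk
k2, v2 = torch.randn(24, D), torch.randn(24, D)    # second K/V chunk
scale = D ** -0.5

def partial_attn(q, k, v):
    s = (k @ q) * scale                  # [N] attention scores
    lse = torch.logsumexp(s, dim=0)      # scalar log-sum-exp (-inf if the chunk were empty)
    out = torch.softmax(s, dim=0) @ v    # [D] chunk output
    return out, lse

(o1, l1), (o2, l2) = partial_attn(q, k1, v1), partial_attn(q, k2, v2)
w = torch.softmax(torch.stack([l1, l2]), dim=0)    # merge weights derived from the LSEs
merged = w[0] * o1 + w[1] * o2

full, _ = partial_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(merged, full, atol=1e-5)
```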


@sm80_or_better_only
@pytest.mark.parametrize("bmghk", (False, True))
def test_merging_attentions_against_ref(bmghk: bool):
def test_merge_attentions_against_ref(bmghk: bool):
split_k = 16
B = 12
M = 137
@@ -2786,12 +2885,12 @@ def test_merging_attentions_against_ref(bmghk: bool):
D_H = 128
dtype = torch.float32

attn_split = torch.randn([split_k, B, N_H_L, G, M, D_H], dtype=dtype, device="cuda")
lse_split = torch.randn([split_k, B, N_H_L, G, M], dtype=dtype, device="cuda")
attn_split = torch.randn([split_k, B, M, G, N_H_L, D_H], dtype=dtype, device="cuda")
lse_split = torch.randn([split_k, B, G, N_H_L, M], dtype=dtype, device="cuda")

if not bmghk:
attn_split = attn_split[:, :, :, 0, :, :]
lse_split = lse_split[:, :, :, 0, :]
attn_split = attn_split[:, :, :, 0]
lse_split = lse_split[:, :, 0]

attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)

@@ -2803,29 +2902,28 @@ def test_merging_attentions_against_ref(bmghk: bool):

def _merge_attentions_ref(attn_split, lse_split):
"""
attn_split: [split_k, B, H, G, M_ceil, Kq]
lse_split: [split_k, B, H, G, M]
attn_split: [split_k, B, M, (G,) H, Kq]
lse_split: [split_k, B, (G,) H, M]
"""
is_bmghk = len(attn_split.shape) == 6
if not is_bmghk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(2)

lse_split = lse_split.unsqueeze(5) # [split_k, B, M, G, H, 1]
lse_split = lse_split[..., None].moveaxis(4, 2) # [split_k, B, M, G, H, 1]

lse_max, _ = torch.max(lse_split, dim=0, keepdim=True) # [1, B, M, G, H, 1]
lse_max, _ = torch.max(lse_split, dim=0) # [B, M, G, H, 1]
sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1]
denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1]
numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, K]

attn_out = numerator / denominator # [B, M_ceil, G, H, Kq]
lse_out = (lse_max.squeeze(0) + torch.log(denominator)).squeeze(
4
) # [B, M_ceil, G, H]
lse_out = lse_max + torch.log(denominator)
lse_out = lse_out.squeeze(4).permute(0, 2, 3, 1) # [B, G, H, M]

if not is_bmghk:
attn_out = attn_out.squeeze(2)
lse_out = lse_out.squeeze(2)
lse_out = lse_out.squeeze(1)

return attn_out, lse_out
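
Reading off the reference implementation above: given per-split partial outputs $o_i$ and log-sum-exps $\ell_i$ (per query, group and head), the merge computes

``` latex
m = \max_i \ell_i, \qquad
o = \frac{\sum_i e^{\ell_i - m}\, o_i}{\sum_i e^{\ell_i - m}}, \qquad
\ell = m + \log \sum_i e^{\ell_i - m} = \log \sum_i e^{\ell_i}.
```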
