memory_efficient_attention_partial

This is a friendly entrypoint for getting MHA with LSE to be used with
merge_attentions.

ghstack-source-id: 2a670a0a7b4550d049b44e40d21ab3c55f67d2e9
Pull Request resolved: fairinternal/xformers#1042

__original_commit__ = fairinternal/xformers@141900b
bottler authored and xFormers Bot committed Mar 1, 2024
1 parent c9d9be2 commit fe0526b
Showing 5 changed files with 143 additions and 6 deletions.
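
For orientation, the following is a minimal usage sketch (not part of this commit) showing how the new entrypoint is meant to feed merge_attentions, modeled on the tests below; the shapes, dtype, and the two-way split of the keys/values are illustrative assumptions.

import torch
import xformers.ops.fmha as fmha

# Illustrative sizes: batch, total K/V length, query length, heads, head dim.
B, M, Mq, H, K = 2, 1024, 3, 8, 128
dtype = torch.bfloat16
q = torch.randn(B, Mq, H, K, device="cuda", dtype=dtype)
k = torch.randn(B, M, H, K, device="cuda", dtype=dtype)
v = torch.randn(B, M, H, K, device="cuda", dtype=dtype)

# Attend to each half of the keys/values separately, keeping each partial LSE.
out0, lse0 = fmha.memory_efficient_attention_partial(q, k[:, : M // 2], v[:, : M // 2])
out1, lse1 = fmha.memory_efficient_attention_partial(q, k[:, M // 2 :], v[:, M // 2 :])

# Some ops pad the LSE along the query dimension, so trim it to Mq before merging.
lse0, lse1 = lse0[..., :Mq], lse1[..., :Mq]

# Merge the partial results into the attention over the full K/V.
attn, lse = fmha.merge_attentions(
    torch.stack([out0, out1]), torch.stack([lse0, lse1]), output_dtype=dtype
)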
64 changes: 62 additions & 2 deletions tests/test_mem_eff_attention.py
@@ -2607,6 +2607,66 @@ def paged_attention_run_inner(
torch.testing.assert_close(y_swapped, y_packed)


@sm80_or_better_only
@pytest.mark.parametrize(
"op",
[
fmha.triton_splitk.FwOp,
fmha.flash.FwOp,
None,
],
ids=lambda op: "None" if op is None else op.NAME,
)
@pytest.mark.parametrize("G,H", [(1, 11), (7, 1), (1, 1), (7, 11), (None, 11)])
@pytest.mark.parametrize(
"write_lse", (False, True), ids=lambda x: "write_lse" if x else ""
)
def test_merge_attentions_nobias(
write_lse: bool, op: Type[AttentionFwOpBase], G: Optional[int], H: int
):
"""
Merging the same attention twice shouldn't change anything.
This also tests the shape of the lse output of each permitted op.
"""
B, M, Mq, K = 13, 5, 3, 128
if op is None or torch.bfloat16 in op.SUPPORTED_DTYPES:
dtype = torch.bfloat16
else:
dtype = next(iter(op.SUPPORTED_DTYPES))
if G is None:
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
v = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
else:
q = 3 * torch.rand(B, Mq, G, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
v = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
out1, lse1 = fmha.memory_efficient_attention_partial(q, k, v, op=op)
assert out1.shape == q.shape
M_ceil = lse1.shape[-1]
assert M_ceil >= Mq
assert lse1.shape == ((B, H, M_ceil) if G is None else (B, G, H, M_ceil))
lse1 = lse1[..., :Mq]

out, lse = fmha.merge_attentions(
torch.stack([out1, out1]), torch.stack([lse1, lse1]), write_lse=write_lse
)
assert out.shape == out1.shape
assert_allclose(out1, out, rtol=1e-3, atol=1e-3, msg="out")
if write_lse:
assert lse is not None
assert lse.shape[:-1] == lse1.shape[:-1]
assert_allclose(
lse1[..., :Mq] + math.log(2), lse[..., :Mq], rtol=1e-3, atol=1e-3, msg="lse"
)
else:
assert lse is None


@sm80_or_better_only
@pytest.mark.parametrize(
@@ -2681,7 +2741,7 @@ def test_merge_attentions_decoding(
)
)

attn_chunk, lse_chunk = fmha.memory_efficient_attention_forward_requires_grad(
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
q,
axk,
axv,
@@ -2700,7 +2760,7 @@
# Merge attention from all chunks
attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output])
lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output])
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split, output_dtype=dtype)
assert lse_out is not None

# Compute attention on the full K/V
62 changes: 59 additions & 3 deletions xformers/ops/fmha/__init__.py
@@ -8,7 +8,12 @@
import torch

from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
@@ -439,8 +444,52 @@ def _memory_efficient_attention_backward(
return grads


def memory_efficient_attention_partial(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple (output, lse), where `output` is the attention and `lse`
is the log-sum-exp of the attention logits. The cat'ed outputs of calls to this with the same query
and separate keys and values can be merged with merge_attentions to obtain
the attention of the queries against the disjoint union of the keys and values.
"""
if p != 0.0:
raise NotImplementedError("dropout is not supported.")
if not isinstance(
attn_bias, (type(None), BlockDiagonalCausalWithOffsetPaddedKeysMask)
):
raise ValueError(
"only BlockDiagonalCausalWithOffsetPaddedKeysMask and no bias supported"
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
is_partial=True,
),
op=op,
)
return out, ctx.lse


def merge_attentions(
attn_split: torch.Tensor, lse_split: torch.Tensor, write_lse: bool = True
attn_split: torch.Tensor,
lse_split: torch.Tensor,
write_lse: bool = True,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Combine attention output computed on different parts of K/V for the same
@@ -453,6 +502,7 @@ def merge_attentions(
Args:
attn_split: [split_k, B, M, G, H, Kq] or [split_k, B, M, H, Kq]
lse_split: [split_k, B, G, H, M] or [split_k, B, H, M]
output_dtype: dtype of attn_out
Returns:
attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
@@ -480,7 +530,13 @@ def merge_attentions(
lse_split = lse_split.permute(1, 2, 3, 0, 4).reshape(B, G * H, split_k, M)

attn_out = torch.empty(
B, M, G, H, Kq, device=attn_split.device, dtype=attn_split.dtype
B,
M,
G,
H,
Kq,
device=attn_split.device,
dtype=attn_split.dtype if output_dtype is None else output_dtype,
)
if write_lse:
lse_out = torch.empty(
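
For reference, the combination merge_attentions performs can be written directly in PyTorch; this is a sketch of the underlying log-sum-exp arithmetic in the [split_k, B, M, H, Kq] / [split_k, B, H, M] layout from the docstring above (the function name is ours, and this is not the fused Triton kernel the library actually dispatches to).

import torch

def merge_attentions_reference(attn_split, lse_split):
    # attn_split: [split_k, B, M, H, Kq], lse_split: [split_k, B, H, M]
    # Global log-sum-exp over the splits, per (batch, head, query).
    lse_out = torch.logsumexp(lse_split.float(), dim=0)        # [B, H, M]
    # Re-weight each partial output by exp(lse_i - lse_out); the weights sum to 1.
    weights = torch.exp(lse_split.float() - lse_out)           # [split_k, B, H, M]
    weights = weights.permute(0, 1, 3, 2).unsqueeze(-1)        # [split_k, B, M, H, 1]
    attn_out = (weights * attn_split.float()).sum(dim=0)       # [B, M, H, Kq]
    return attn_out, lse_out

Merging two identical partial results leaves the output unchanged and shifts the LSE by log(2), which is exactly what test_merge_attentions_nobias above asserts.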
6 changes: 6 additions & 0 deletions xformers/ops/fmha/common.py
@@ -53,6 +53,7 @@ class Inputs:
p: float = 0.0
scale: Optional[float] = None
output_dtype: Optional[torch.dtype] = None
is_partial: bool = False

@property
def device(self) -> torch.device:
@@ -206,6 +207,8 @@ def validate_inputs(self) -> None:

def get_output_dtype(self) -> torch.dtype:
if self.output_dtype is None:
if self.is_partial and self.query.dtype is not torch.float64:
return torch.float32
return self.query.dtype
return self.output_dtype

@@ -264,6 +267,7 @@ class AttentionOpBase(BaseOperator):
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
SUPPORTS_OUTPUT_DTYPE: bool = False
SUPPORTS_PARTIAL: bool = False
IS_DETERMINISTIC: bool = True
SUPPORTS_BMGHK: bool = False
NAME: str
@@ -322,6 +326,8 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
if not cls.SUPPORTS_OUTPUT_DTYPE:
if d.output_dtype is not None and d.output_dtype is not dtype:
reasons.append("Custom output dtype not supported")
if d.is_partial and not cls.SUPPORTS_PARTIAL:
reasons.append("Partial attention not supported")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
reasons.append("dropout > 0.0")
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
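
The get_output_dtype change above means a partial call accumulates its result in float32 by default (unless the inputs are float64 or an explicit output_dtype is passed), presumably so that the later merge loses as little precision as possible. A small sketch of that dispatch logic, assuming Inputs can be constructed directly with these keyword fields:

import torch
from xformers.ops.fmha.common import Inputs

q = torch.randn(1, 3, 8, 128, dtype=torch.bfloat16)
inp = Inputs(query=q, key=q, value=q, is_partial=True)
assert inp.get_output_dtype() is torch.float32  # partial results default to fp32
inp = Inputs(query=q, key=q, value=q, is_partial=True, output_dtype=torch.bfloat16)
assert inp.get_output_dtype() is torch.bfloat16  # explicit override wins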
13 changes: 12 additions & 1 deletion xformers/ops/fmha/flash.py
@@ -403,9 +403,12 @@ def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) ->


def _post_process_lse(
lse: torch.Tensor, inp: Inputs, original_query_shape
lse: torch.Tensor, inp: Inputs, original_query_shape: Tuple[int, ...]
) -> torch.Tensor:
if not isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
if inp.is_partial and inp.attn_bias is None and len(original_query_shape) == 5:
# [B, GH, M] => [B, G, H, M]
return lse.unflatten(1, original_query_shape[2:4])
return lse
q_seqinfo = inp.attn_bias.q_seqinfo
B = len(q_seqinfo.seqstart_py) - 1
@@ -450,6 +453,7 @@ class FwOp(AttentionFwOpBase):
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_DIFFERENT_VALUE_EMBED = False
SUPPORTS_BMGHK = True
SUPPORTS_PARTIAL = True
NAME = f"flshattF@{FLASH_VERSION}"
VERSION = FLASH_VERSION

@@ -461,6 +465,13 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
_check_strides_for_bmghk(d.query, "query", reasons)
_check_strides_for_bmghk(d.key, "key", reasons)
_check_strides_for_bmghk(d.value, "value", reasons)
if d.is_partial and isinstance(
d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask
):
q_seqinfo = d.attn_bias.q_seqinfo
if q_seqinfo.min_seqlen != q_seqinfo.max_seqlen:
# Flash provides padded LSE which we don't handle.
reasons.append("partial attention with heterogeneous queries")
return reasons

@classmethod
4 changes: 4 additions & 0 deletions xformers/ops/fmha/triton_splitk.py
@@ -668,6 +668,9 @@ def _splitK_reduce(
+ stride_om * off_m
+ tl.arange(0, BLOCK_SIZE)
)
if acc.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64:
# must avoid direct cast f64->f16
acc = acc.to(tl.float32)
tl.store(Out_ptr, acc)

if WRITE_LSE:
@@ -724,6 +727,7 @@ class FwOp(AttentionFwOpBase):
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_BMGHK = True
SUPPORTS_OUTPUT_DTYPE = True
SUPPORTS_PARTIAL = True
NAME = "triton_splitKF"

SPLIT_K: Optional[int] = None