partial attention bwd (fairinternal/xformers#1130)
* partial attention bwd

* fix

* address comments

---------

Co-authored-by: bottler <[email protected]>

__original_commit__ = fairinternal/xformers@8344854
bottler authored and xFormers Bot committed Jun 12, 2024
1 parent 8ce361e commit 0de8212
Showing 3 changed files with 121 additions and 23 deletions.
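
Before the file-by-file diff, here is the pattern this commit makes differentiable: compute partial attention of a query against several disjoint key/value chunks, merge the pieces with merge_attentions, and backpropagate through the merged result. This is a minimal sketch based on the new test below; it assumes a CUDA device with compatible operators, and the shapes and the three-way split are illustrative only.

import torch
from xformers.ops import fmha

# Query attends to three disjoint key/value chunks; each partial call
# returns its output together with its log-sum-exp (LSE).
q = torch.rand(2, 8, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
kv_chunks = [
    (
        torch.rand(2, 16, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True),
        torch.rand(2, 16, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True),
    )
    for _ in range(3)
]

parts = [fmha.memory_efficient_attention_partial(q, k, v) for k, v in kv_chunks]
attn_split, lse_split = zip(*parts)

# merge_attentions recombines the partial outputs into the attention of q
# against the union of all keys/values; with this commit the merged output
# is differentiable w.r.t. q and every key/value chunk.
out, _ = fmha.merge_attentions(list(attn_split), list(lse_split), write_lse=True)
out.backward(torch.ones_like(out))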
64 changes: 63 additions & 1 deletion tests/test_mem_eff_attention.py
@@ -7,7 +7,7 @@
import math
import random
from functools import partial
from typing import List, Optional, Sequence, Tuple, Type, TypeVar
from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union

import pytest
import torch
@@ -2837,6 +2837,68 @@ def test_merge_attentions_nobias(
assert lse is None


@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"op",
[
pytest.param(fmha.flash.FwOp, id="flashfwd"),
pytest.param((fmha.flash.FwOp, fmha.cutlass.BwOp), id="flashcutlass"),
# pytest.param((fmha.triton_splitk.FwOp, fmha.cutlass.BwOp), id="splitk"), # XXX
pytest.param(fmha.MemoryEfficientAttentionFlashAttentionOp, id="flash"),
None,
],
)
def test_merge_attentions_nobias_bwd(
op: Union[Type[AttentionFwOpBase], fmha.AttentionOp]
):
B, M, Mq, H, K = 13, 5, 5, 4, 128
dtype = torch.bfloat16
nparts = 3
torch.manual_seed(1)
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
kv = [
[3 * (torch.rand(B, M, H, K, dtype=dtype, device="cuda")) for _ in range(2)]
for _ in range(nparts)
]
q = q.requires_grad_(True)
kv = [[j.requires_grad_(True) for j in i] for i in kv]
out_parts = [fmha.memory_efficient_attention_partial(q, k, v, op=op) for k, v in kv]
attn_split, lse_split = [list(x) for x in zip(*out_parts)]
out_merged = fmha.merge_attentions(attn_split, lse_split, write_lse=True)[0]
grad_out = torch.rand_like(q)
out_merged.backward(grad_out)
grad_q_out = q.grad
assert q.grad is not None
grad_kv_out = [[j.grad for j in i] for i in kv]
q = q.detach().requires_grad_(True)
kv = [[j.detach().requires_grad_(True) for j in i] for i in kv]

k2, v2 = [torch.cat([i[j] for i in kv], dim=1) for j in range(2)]

if op is None or isinstance(op, tuple):
full_op = op
else:
full_op = (op, None)
out_full = fmha.memory_efficient_attention(q, k2, v2, op=full_op) # type: ignore
out_full.backward(grad_out)
assert_allclose(
out_merged, out_full.to(out_merged.dtype), rtol=1e-2, atol=2e-2, msg="out"
)
atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype] * 1.5
rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype]
assert_allclose(grad_q_out, q.grad, rtol=rtol, atol=atol, msg="qgrad")
for i in range(nparts):
for j in range(2):
assert_allclose(
grad_kv_out[i][j],
kv[i][j].grad,
rtol=rtol,
atol=atol,
msg=f"kvgrad {i} {j}",
)


@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
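For background on what the new backward pass has to be consistent with: merging partial attention outputs reduces to a log-sum-exp combination of the per-chunk results. The reference implementation below only illustrates that math (it assumes LSEs laid out as (B, H, Mq)); it is not the kernel that merge_attentions actually runs.

import torch

def merge_partials_reference(attn_split, lse_split):
    # attn_split: list of partial outputs, each of shape (B, Mq, H, K)
    # lse_split: list of per-chunk log-sum-exp tensors, each (B, H, Mq)
    lse = torch.stack(lse_split)                          # (nparts, B, H, Mq)
    lse_merged = torch.logsumexp(lse, dim=0)              # (B, H, Mq)
    # Each chunk's softmax was normalized locally, so its contribution is
    # reweighted by exp(lse_i - lse_merged) before summing.
    weights = torch.exp(lse - lse_merged)                 # (nparts, B, H, Mq)
    weights = weights.permute(0, 1, 3, 2).unsqueeze(-1)   # (nparts, B, Mq, H, 1)
    out = torch.stack(attn_split)                         # (nparts, B, Mq, H, K)
    return (weights.to(out.dtype) * out).sum(dim=0), lse_merged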
79 changes: 57 additions & 22 deletions xformers/ops/fmha/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple, Type, Union, cast
from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast

import torch

@@ -127,11 +127,11 @@ def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
ctx.scale = inp.scale
ctx.attn_bias_ctx = attn_bias_ctx
ctx.n_args = len(args)
return out
return out, op_ctx.lse

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad):
def backward(ctx, grad, grad_lse):
# Re-create context
query, key, value, out, lse = ctx.saved_tensors
attn_bias_tensor = ctx.attn_bias_tensor
@@ -402,7 +402,7 @@ def _memory_efficient_attention(
op_bw = _serialize_op(op[1] if op is not None else None)
return _fMHA.apply(
op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
).reshape(output_shape)
)[0].reshape(output_shape)


def _memory_efficient_attention_forward(
@@ -534,7 +534,7 @@ def memory_efficient_attention_partial(
p: float = 0.0,
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -543,9 +543,14 @@
The outputs of calls to this with the same query and separate keys and values
can be merged with merge_attentions to obtain the attention of the queries
against the disjoint union of the keys and values.
Warning: The backward pass of this function is quite restricted. In particular,
we assume that in the forward pass the outputs were only used in merge_attentions
calculations, and that the LSEs weren't used anywhere except in merge_attentions.
"""
if p != 0.0:
raise NotImplementedError("dropout is not supported.")
fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op
if not (
isinstance(
attn_bias,
@@ -559,31 +564,61 @@
_attn_bias.LowerTriangularMask,
),
)
or op is None
or op.UNPADDED_LSE
or fwop is None
or fwop.UNPADDED_LSE
):
raise ValueError(
f"{type(attn_bias)} is not supported in memory_efficient_attention_partial."
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
is_partial=True,
),
op=op,
inp = Inputs(
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
is_partial=True,
)

is_grad = torch.is_grad_enabled() and any(
x.requires_grad for x in [query, key, value]
)

if not is_grad:
out, ctx = _memory_efficient_attention_forward_requires_grad(
inp,
op=fwop,
)
return out, ctx.lse

if query.ndim == 5:
raise ValueError("gradients not supported for 5D tensors")
if isinstance(op, tuple):
op_fw = _serialize_op(op[0])
op_bw = _serialize_op(op[1])
elif op is None:
op_fw = op_bw = None
else:
op_fw = _serialize_op(op)
op_bw = None
return _fMHA.apply(
op_fw,
op_bw,
inp.query,
inp.key,
inp.value,
inp.attn_bias,
inp.p,
inp.scale,
inp.output_dtype,
inp.is_partial,
)
return out, ctx.lse


def merge_attentions(
attn_split: Union[torch.Tensor, List[torch.Tensor]],
lse_split: Union[torch.Tensor, List[torch.Tensor]],
attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
write_lse: bool = True,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
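Since op on memory_efficient_attention_partial now accepts either a bare forward class or a (forward, backward) pair, callers that need gradients can pick operators the same way as with memory_efficient_attention. A short sketch mirroring the test parametrization; which operators are actually usable depends on the build and GPU.

import torch
from xformers.ops import fmha

q = torch.rand(2, 8, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.rand(2, 16, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.rand(2, 16, 4, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# A bare forward op: the backward op is left to be resolved automatically.
out, lse = fmha.memory_efficient_attention_partial(q, k, v, op=fmha.flash.FwOp)

# An explicit (forward, backward) pair, as also accepted by memory_efficient_attention.
out, lse = fmha.memory_efficient_attention_partial(
    q, k, v, op=(fmha.flash.FwOp, fmha.cutlass.BwOp)
)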
1 change: 1 addition & 0 deletions xformers/ops/fmha/common.py
@@ -450,6 +450,7 @@ class AttentionBwOpBase(AttentionOpBase):
torch.bfloat16: 0.1,
}
SUPPORTS_ATTN_BIAS_GRAD = False
SUPPORTS_PARTIAL = True
SUPPORTS_UNPADDED_LSE = False

@classmethod