output_dtype for triton_splitk
Allows asking for higher output precision, and makes the
merge_attentions test pass with a lower tolerance.

ghstack-source-id: e9916afbf7da39a23a7c91579baccb493cdc626b
Pull Request resolved: fairinternal/xformers#1040

__original_commit__ = fairinternal/xformers@72549fc
bottler authored and xFormers Bot committed Feb 29, 2024
1 parent 41ebc03 commit c9d9be2
Showing 4 changed files with 69 additions and 17 deletions.
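
A minimal usage sketch of the new output_dtype argument (the shapes, CUDA device, bf16 inputs and op choice below are illustrative assumptions, not taken from this commit): triton_splitk.FwOp now declares SUPPORTS_OUTPUT_DTYPE, so a caller can request the attention output in a higher-precision dtype than the query/key/value.

import torch
import xformers.ops.fmha as fmha

# bf16 inputs in BMHK layout: (batch, seqlen, num_heads, head_dim) -- assumed shapes
q = torch.randn(1, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)

# Ask the split-K Triton kernel for an fp32 output; without output_dtype
# the result would match q.dtype (bf16).
out = fmha.memory_efficient_attention_forward(
    q, k, v,
    op=fmha.triton_splitk.FwOp,
    output_dtype=torch.float32,
)
assert out.dtype == torch.float32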
16 changes: 12 additions & 4 deletions tests/test_mem_eff_attention.py
@@ -2612,7 +2612,7 @@ def paged_attention_run_inner(
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp),
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
@@ -2634,6 +2634,7 @@ def test_merge_attentions_decoding(
D_H = 128
G = 2 if bmghk else 1
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None

num_chunks = 10

@@ -2686,6 +2687,7 @@ def test_merge_attentions_decoding(
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
@@ -2699,6 +2701,7 @@ def test_merge_attentions_decoding(
attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output])
lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output])
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)
assert lse_out is not None

# Compute attention on the full K/V
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
@@ -2717,12 +2720,17 @@ def test_merge_attentions_decoding(
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)

atol = op.ERROR_ATOL[dtype] * (10 if op is fmha.triton_splitk.FwOp else 1)
atol = op.ERROR_ATOL[dtype]
rtol = op.ERROR_RTOL[dtype] * 2
assert_allclose(lse_out, lse_full, rtol=rtol * 2, atol=atol, msg="lse")
assert_allclose(attn_out, attn_full, rtol=rtol, atol=atol, msg="out")
assert_allclose(
lse_out.to(lse_full.dtype), lse_full, rtol=rtol * 2, atol=atol, msg="lse"
)
assert_allclose(
attn_out.to(attn_full.dtype), attn_full, rtol=rtol, atol=atol, msg="out"
)


@sm80_or_better_only
29 changes: 25 additions & 4 deletions xformers/ops/fmha/__init__.py
@@ -122,6 +122,7 @@ def memory_efficient_attention(
scale: Optional[float] = None,
*,
op: Optional[AttentionOp] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""Implements the memory-efficient attention mechanism following
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
@@ -226,7 +227,13 @@ def memory_efficient_attention(
"""
return _memory_efficient_attention(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
@@ -241,13 +248,20 @@ def memory_efficient_attention_forward(
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
"""
return _memory_efficient_attention_forward(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
@@ -262,6 +276,7 @@ def memory_efficient_attention_forward_requires_grad(
scale: Optional[float] = None,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
@@ -275,7 +290,13 @@ def memory_efficient_attention_forward_requires_grad(
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
query=query,
key=key,
value=value,
p=p,
attn_bias=attn_bias,
scale=scale,
output_dtype=output_dtype,
),
op=op,
)
@@ -463,7 +484,7 @@ def merge_attentions(
)
if write_lse:
lse_out = torch.empty(
B * H * G, M, device=attn_split.device, dtype=torch.float32
B * H * G, M, device=attn_split.device, dtype=lse_split.dtype
)
else:
lse_out = None
10 changes: 10 additions & 0 deletions xformers/ops/fmha/common.py
@@ -52,6 +52,7 @@ class Inputs:
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
p: float = 0.0
scale: Optional[float] = None
output_dtype: Optional[torch.dtype] = None

@property
def device(self) -> torch.device:
@@ -203,6 +204,11 @@ def validate_inputs(self) -> None:
"yourself before calling `memory_efficient_attention` if you need to"
)

def get_output_dtype(self) -> torch.dtype:
if self.output_dtype is None:
return self.query.dtype
return self.output_dtype


@dataclass
class Context:
@@ -257,6 +263,7 @@ class AttentionOpBase(BaseOperator):
SUPPORTS_DROPOUT: bool
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
SUPPORTS_OUTPUT_DTYPE: bool = False
IS_DETERMINISTIC: bool = True
SUPPORTS_BMGHK: bool = False
NAME: str
@@ -312,6 +319,9 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
if not cls.SUPPORTS_OUTPUT_DTYPE:
if d.output_dtype is not None and d.output_dtype is not dtype:
reasons.append("Custom output dtype not supported")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
reasons.append("dropout > 0.0")
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
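
A hedged sketch of the new validation path in common.py (the tensors and the comparison with the flash op are assumptions for illustration, not part of the diff): an op that keeps the default SUPPORTS_OUTPUT_DTYPE = False reports a custom output dtype through not_supported_reasons, whereas triton_splitk.FwOp opts in.

import torch
from xformers.ops import fmha
from xformers.ops.fmha.common import Inputs

q = torch.randn(1, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)

inp = Inputs(query=q, key=k, value=v, output_dtype=torch.float32)

# flash.FwOp does not set SUPPORTS_OUTPUT_DTYPE, so requesting an fp32 output
# from bf16 inputs should show up as a reason, e.g.
# "Custom output dtype not supported".
print(fmha.flash.FwOp.not_supported_reasons(inp))

# triton_splitk.FwOp sets SUPPORTS_OUTPUT_DTYPE = True, so that particular
# reason should be absent here.
print(fmha.triton_splitk.FwOp.not_supported_reasons(inp))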
31 changes: 22 additions & 9 deletions xformers/ops/fmha/triton_splitk.py
@@ -123,6 +123,9 @@ def _fwd_kernel_splitK(
Set IS_SPLITK=False to indicate the MHA result should be written directly.
No metadata will be written.
"""
internal_dtype = (
tl.float64 if Out_splitK.dtype.element_ty is tl.float64 else tl.float32
)
tl.static_assert(
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
@@ -239,7 +242,9 @@ def _fwd_kernel_splitK(
acc: "VAR_ARGS_ARRAY" # noqa: F821

for i in range(len(acc)): # noqa: F821
acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821
acc[i] = tl.zeros( # noqa: F821
[BLOCK_M, D_PER_GROUP], dtype=internal_dtype
)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
@@ -718,6 +723,7 @@ class FwOp(AttentionFwOpBase):
SUPPORTS_DROPOUT = False
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_BMGHK = True
SUPPORTS_OUTPUT_DTYPE = True
NAME = "triton_splitKF"

SPLIT_K: Optional[int] = None
@@ -818,6 +824,7 @@ def get_split_k(cls, B: int, H: int, Mk: int) -> int:
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
output_dtype = inp.get_output_dtype()
attn_bias = inp.attn_bias
seq_len = None
q, k, v = inp.get_qkv_in_bmghk()
@@ -905,24 +912,30 @@ def apply(
M_ceil = (M + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M
IS_SPLITK = split_k > 1 # or cls.autotune?
if IS_SPLITK:
o_splitk_dtype = (
torch.float64 if output_dtype == torch.float64 else torch.float32
)
o_splitk = torch.empty(
[B, G * H, split_k, M_ceil, Kq],
dtype=torch.float32,
dtype=o_splitk_dtype,
device=q.device,
)
else:
o_splitk = torch.empty(
[B, split_k, M, G * H, Kq],
dtype=q.dtype,
dtype=output_dtype,
device=q.device,
).permute(0, 3, 1, 2, 4)
lse, lse_splitk = None, None
# LSE may need higher precision than output
output_f64_lse = output_dtype in (torch.float32, torch.float64)
if IS_SPLITK and needs_gradient:
lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
lse_dtype = torch.float64 if output_f64_lse else torch.float32
lse = torch.empty((B * G * H, M), device=q.device, dtype=lse_dtype)
if IS_SPLITK or needs_gradient:
lse_splitk = torch.empty(
[B, G * H, split_k, M],
dtype=torch.float64 if IS_SPLITK else torch.float32,
dtype=torch.float64 if IS_SPLITK or output_f64_lse else torch.float32,
device=q.device,
)

@@ -1002,11 +1015,11 @@ def grid(META):
return out, Context(out=out, lse=lse)

if mqa_swap_seqlen_head:
out = torch.empty((B, G, M, 1, Kq), device=q.device, dtype=q.dtype).permute(
0, 2, 1, 3, 4
)
out = torch.empty(
(B, G, M, 1, Kq), device=q.device, dtype=output_dtype
).permute(0, 2, 1, 3, 4)
else:
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=output_dtype)

# Merge attention and LSE outputs from different split-k chunks
assert lse_splitk is not None
