diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index b6bd84a0bd..c4ff7562ac 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -2612,7 +2612,7 @@ def paged_attention_run_inner(
 @pytest.mark.parametrize(
     "dtype,op",
     [
-        (torch.bfloat16, fmha.triton_splitk.FwOp),
+        (torch.bfloat16, fmha.triton_splitk.FwOp_S1),
         # Cutlass's LSE is not consistent
         # (torch.float32, fmha.cutlass.FwOp),
         (torch.bfloat16, fmha.flash.FwOp),
@@ -2634,6 +2634,7 @@ def test_merge_attentions_decoding(
     D_H = 128
     G = 2 if bmghk else 1
     torch.manual_seed(1)
+    output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None
 
     num_chunks = 10
 
@@ -2686,6 +2687,7 @@ def test_merge_attentions_decoding(
             axv,
             attn_bias,
             op=op,
+            output_dtype=output_dtype,
         )
         if bmghk:
             assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
@@ -2699,6 +2701,7 @@ def test_merge_attentions_decoding(
     attn_split = torch.stack([attn_chunk for attn_chunk, _ in chunks_output])
     lse_split = torch.stack([lse_chunk for _, lse_chunk in chunks_output])
     attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)
+    assert lse_out is not None
 
     # Compute attention on the full K/V
     attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
@@ -2717,12 +2720,17 @@ def test_merge_attentions_decoding(
         axv,
         attn_bias,
         op=op,
+        output_dtype=output_dtype,
     )
 
-    atol = op.ERROR_ATOL[dtype] * (10 if op is fmha.triton_splitk.FwOp else 1)
+    atol = op.ERROR_ATOL[dtype]
     rtol = op.ERROR_RTOL[dtype] * 2
-    assert_allclose(lse_out, lse_full, rtol=rtol * 2, atol=atol, msg="lse")
-    assert_allclose(attn_out, attn_full, rtol=rtol, atol=atol, msg="out")
+    assert_allclose(
+        lse_out.to(lse_full.dtype), lse_full, rtol=rtol * 2, atol=atol, msg="lse"
+    )
+    assert_allclose(
+        attn_out.to(attn_full.dtype), attn_full, rtol=rtol, atol=atol, msg="out"
+    )
 
 
 @sm80_or_better_only
diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py
index 9265fc25be..f7aeafda90 100644
--- a/xformers/ops/fmha/__init__.py
+++ b/xformers/ops/fmha/__init__.py
@@ -122,6 +122,7 @@ def memory_efficient_attention(
     scale: Optional[float] = None,
     *,
     op: Optional[AttentionOp] = None,
+    output_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """Implements the memory-efficient attention mechanism following
     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
@@ -226,7 +227,13 @@
     """
     return _memory_efficient_attention(
         Inputs(
-            query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
+            query=query,
+            key=key,
+            value=value,
+            p=p,
+            attn_bias=attn_bias,
+            scale=scale,
+            output_dtype=output_dtype,
         ),
         op=op,
     )
@@ -241,13 +248,20 @@ def memory_efficient_attention_forward(
     scale: Optional[float] = None,
     *,
     op: Optional[Type[AttentionFwOpBase]] = None,
+    output_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
     """
     return _memory_efficient_attention_forward(
         Inputs(
-            query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
+            query=query,
+            key=key,
+            value=value,
+            p=p,
+            attn_bias=attn_bias,
+            scale=scale,
+            output_dtype=output_dtype,
         ),
         op=op,
     )
@@ -262,6 +276,7 @@ def memory_efficient_attention_forward_requires_grad(
     scale: Optional[float] = None,
     *,
     op: Optional[Type[AttentionFwOpBase]] = None,
+    output_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
@@ -275,7 +290,13 @@
     )
     out, ctx = _memory_efficient_attention_forward_requires_grad(
         Inputs(
-            query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
+            query=query,
+            key=key,
+            value=value,
+            p=p,
+            attn_bias=attn_bias,
+            scale=scale,
+            output_dtype=output_dtype,
         ),
         op=op,
     )
@@ -463,7 +484,7 @@ def merge_attentions(
     )
     if write_lse:
         lse_out = torch.empty(
-            B * H * G, M, device=attn_split.device, dtype=torch.float32
+            B * H * G, M, device=attn_split.device, dtype=lse_split.dtype
         )
     else:
         lse_out = None
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index a7aaf7d887..7eddf11eb0 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -52,6 +52,7 @@ class Inputs:
     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
     p: float = 0.0
     scale: Optional[float] = None
+    output_dtype: Optional[torch.dtype] = None
 
     @property
     def device(self) -> torch.device:
@@ -203,6 +204,11 @@ def validate_inputs(self) -> None:
                 "yourself before calling `memory_efficient_attention` if you need to"
             )
 
+    def get_output_dtype(self) -> torch.dtype:
+        if self.output_dtype is None:
+            return self.query.dtype
+        return self.output_dtype
+
 
 @dataclass
 class Context:
@@ -257,6 +263,7 @@ class AttentionOpBase(BaseOperator):
     SUPPORTS_DROPOUT: bool
     SUPPORTS_CUSTOM_SCALE: bool = False
     SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
+    SUPPORTS_OUTPUT_DTYPE: bool = False
     IS_DETERMINISTIC: bool = True
     SUPPORTS_BMGHK: bool = False
     NAME: str
@@ -312,6 +319,9 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
             reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
         if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
             reasons.append(f"attn_bias type is {type(d.attn_bias)}")
+        if not cls.SUPPORTS_OUTPUT_DTYPE:
+            if d.output_dtype is not None and d.output_dtype is not dtype:
+                reasons.append("Custom output dtype not supported")
         if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
             reasons.append("dropout > 0.0")
         if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index b30bcd265f..7c6586917f 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -123,6 +123,9 @@ def _fwd_kernel_splitK(
     Set IS_SPLITK=False to indicate the MHA result should be written directly.
     No metadata will be written.
     """
+    internal_dtype = (
+        tl.float64 if Out_splitK.dtype.element_ty is tl.float64 else tl.float32
+    )
     tl.static_assert(
         (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
         or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
@@ -239,7 +242,9 @@ def _fwd_kernel_splitK(
     acc: "VAR_ARGS_ARRAY"  # noqa: F821
     for i in range(len(acc)):  # noqa: F821
-        acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32)  # noqa: F821
+        acc[i] = tl.zeros(  # noqa: F821
+            [BLOCK_M, D_PER_GROUP], dtype=internal_dtype
+        )
 
     # scale sm_scale by log_2(e) and use
     # 2^x instead of exp in the loop because CSE and LICM
     # don't work as expected with `exp` in the loop
@@ -718,6 +723,7 @@ class FwOp(AttentionFwOpBase):
     SUPPORTS_DROPOUT = False
     SUPPORTS_CUSTOM_SCALE = True
     SUPPORTS_BMGHK = True
+    SUPPORTS_OUTPUT_DTYPE = True
     NAME = "triton_splitKF"
 
     SPLIT_K: Optional[int] = None
@@ -818,6 +824,7 @@ def get_split_k(cls, B: int, H: int, Mk: int) -> int:
     def apply(
         cls, inp: Inputs, needs_gradient: bool
     ) -> Tuple[torch.Tensor, Optional[Context]]:
+        output_dtype = inp.get_output_dtype()
         attn_bias = inp.attn_bias
         seq_len = None
         q, k, v = inp.get_qkv_in_bmghk()
@@ -905,24 +912,30 @@ def apply(
         M_ceil = (M + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M
         IS_SPLITK = split_k > 1  # or cls.autotune?
         if IS_SPLITK:
+            o_splitk_dtype = (
+                torch.float64 if output_dtype == torch.float64 else torch.float32
+            )
             o_splitk = torch.empty(
                 [B, G * H, split_k, M_ceil, Kq],
-                dtype=torch.float32,
+                dtype=o_splitk_dtype,
                 device=q.device,
             )
         else:
             o_splitk = torch.empty(
                 [B, split_k, M, G * H, Kq],
-                dtype=q.dtype,
+                dtype=output_dtype,
                 device=q.device,
             ).permute(0, 3, 1, 2, 4)
         lse, lse_splitk = None, None
+        # LSE may need higher precision than output
+        output_f64_lse = output_dtype in (torch.float32, torch.float64)
         if IS_SPLITK and needs_gradient:
-            lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
+            lse_dtype = torch.float64 if output_f64_lse else torch.float32
+            lse = torch.empty((B * G * H, M), device=q.device, dtype=lse_dtype)
         if IS_SPLITK or needs_gradient:
             lse_splitk = torch.empty(
                 [B, G * H, split_k, M],
-                dtype=torch.float64 if IS_SPLITK else torch.float32,
+                dtype=torch.float64 if IS_SPLITK or output_f64_lse else torch.float32,
                 device=q.device,
             )
 
@@ -1002,11 +1015,11 @@ def grid(META):
             return out, Context(out=out, lse=lse)
 
         if mqa_swap_seqlen_head:
-            out = torch.empty((B, G, M, 1, Kq), device=q.device, dtype=q.dtype).permute(
-                0, 2, 1, 3, 4
-            )
+            out = torch.empty(
+                (B, G, M, 1, Kq), device=q.device, dtype=output_dtype
+            ).permute(0, 2, 1, 3, 4)
         else:
-            out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
+            out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=output_dtype)
 
         # Merge attention and LSE outputs from different split-k chunks
         assert lse_splitk is not None