diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index a5ce9056a2..9edacd0672 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -2236,4 +2236,66 @@ def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K):
     )
 
 
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_query(
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+    query, key, value, attn_bias = create_tensors(
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+        fmt="BMHK",
+    )
+    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+    query = query[:, :0]
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+    assert out.shape[1] == 0
+    out.backward(out)
+    # dK/dV should be all zeros
+    assert_allclose(key.grad, torch.zeros_like(key.grad), "key.grad")
+    assert_allclose(value.grad, torch.zeros_like(value.grad), "value.grad")
+
+
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_kv(
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+    query, key, value, attn_bias = create_tensors(
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+        fmt="BMHK",
+    )
+    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+    key = key[:, :0]
+    value = value[:, :0]
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+    assert_allclose(out, torch.zeros_like(out), "out")
+    out.backward(out)
+    # dQ should be all zeros
+    assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad")
+
+
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_b(
+    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+    query, key, value, attn_bias = create_tensors(
+        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+        fmt="BMHK",
+    )
+    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+    query, key, value = query[:0], key[:0], value[:0]
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+    out.backward(out)
+
+
 # end of file
diff --git a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
index 6efc0e9465..28ff346a0f 100644
--- a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
+++ b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
@@ -353,6 +353,15 @@ mem_efficient_attention_backward_cutlass(
         workspace.zero_();
       }
     }
+
+    // Handle the edge-cases where some tensors are empty
+    if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 ||
+        p.num_heads == 0) {
+      grad_k.zero_();
+      grad_v.zero_();
+      grad_q.zero_();
+      return;
+    }
     Kernel::check_supported(p);
 
     if (smem_bytes > 0xc000) {
diff --git a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
index 365ac999fc..5a08fb591c 100644
--- a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
+++ b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
@@ -279,8 +279,13 @@ efficient_attention_forward_cutlass(
           " kb)");
       AT_CUDA_CHECK(err);
     }
+    auto blocks = p.getBlocksGrid();
+    if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) {
+      res.zero_();
+      return;
+    }
     Kernel::check_supported(p);
-    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
+    kernel_fn<<<blocks, p.getThreadsGrid(), smem_bytes, stream>>>(p);
   };
 
   // Dispatch to the right kernel
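The new tests and the two cutlass guards above define the user-visible contract for empty tensors. A minimal standalone sketch of that contract through the public xformers.ops API (the shapes, dtype and CUDA device below are illustrative assumptions, not taken from this patch):

    import torch
    import xformers.ops as xops

    B, Mq, Mkv, H, K = 2, 8, 16, 3, 64  # arbitrary example shapes
    q = torch.randn(B, Mq, H, K, device="cuda", dtype=torch.float16)
    k = torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.float16)
    v = torch.randn(B, Mkv, H, K, device="cuda", dtype=torch.float16)

    # Empty query: the kernel launch is skipped and the output simply has 0 queries.
    out = xops.memory_efficient_attention(q[:, :0], k, v)
    assert out.shape == (B, 0, H, K)

    # Empty key/value: attention over zero keys is defined as an all-zero output,
    # which is what the res.zero_() / grad_*.zero_() paths above implement.
    out = xops.memory_efficient_attention(q, k[:, :0], v[:, :0])
    assert torch.equal(out, torch.zeros_like(out))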
diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
index 4fb2ae8d46..3276e11612 100644
--- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h
+++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -1179,8 +1179,12 @@ struct AttentionBackwardKernel {
     CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
-    XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned");
-    XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned");
+    XFORMERS_CHECK(
+        p.num_heads <= 1 || p.lse_strideH % 8 == 0,
+        "LSE is not correctly aligned (strideH)");
+    XFORMERS_CHECK(
+        p.num_batches <= 1 || p.lse_strideB % 8 == 0,
+        "LSE is not correctly aligned (strideB)");
     XFORMERS_CHECK(
         p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
         "query is not correctly aligned (strideH)");
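The relaxed LSE checks above rest on a general stride fact: a stride along a dimension of size 0 or 1 is never multiplied by a non-zero index, so its alignment cannot matter. A quick PyTorch illustration of that fact (the shapes here are made up, not from the patch):

    import torch

    # LSE-like tensor laid out as (B, H, Mq) with H == 1: the stride over the head
    # dimension is 123, which is not a multiple of 8, yet no element address ever
    # uses it, so skipping the lse_strideH alignment check is safe in this case.
    lse = torch.randn(4, 1, 123, dtype=torch.float32)
    assert lse.stride(1) == 123 and lse.stride(1) % 8 != 0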
diff --git a/xformers/csrc/attention/cuda/fmha/small_k.cu b/xformers/csrc/attention/cuda/fmha/small_k.cu
index 09ced11b99..52632f64a1 100644
--- a/xformers/csrc/attention/cuda/fmha/small_k.cu
+++ b/xformers/csrc/attention/cuda/fmha/small_k.cu
@@ -652,6 +652,11 @@ void launch_attention(
   dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B);
   dim3 block(WARP_SIZE, TILE_SIZE / kBlockSizeQ);
 
+  if (grid.x * grid.y * grid.z == 0 || key.numel() == 0) {
+    res.zero_();
+    return;
+  }
+
   using scalar_t = float;
 
   auto attn_bias_packed = _packed_tensor_accessor_or_dummy(attn_bias);
@@ -1100,6 +1105,9 @@ void launch_attention_backward(
   dim3 grid(
       ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B);
   dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK);
+  if (grid.x * grid.y * grid.z == 0) {
+    return;
+  }
 
   // the bounds checking in device code is very expensive, making the code
   // around 25% slower. So let's skip those checks if possible.
@@ -1444,10 +1452,11 @@ at::Tensor _dropout_mask(at::Tensor output, double p) {
 
   // invert from drop probability to keep probability
   p = 1.0 - p;
 
-  dropout_kernel
-      <<<grid, block, 0, stream>>>(
-          output.packed_accessor(), p, rng_engine_inputs);
-
+  if (grid.x * grid.y * grid.z > 0) {
+    dropout_kernel
+        <<<grid, block, 0, stream>>>(
+            output.packed_accessor(), p, rng_engine_inputs);
+  }
   return output;
 }
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index b2d013b099..4fbf26c6aa 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -515,9 +515,9 @@ def from_arguments(
 def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
     if tensor.ndim == 4:
         return tensor
-    return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
-        (0, 2, 1, 3)
-    )
+    return tensor.reshape(
+        [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
+    ).permute((0, 2, 1, 3))
 
 
 def check_lastdim_alignment_stride1(
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index ed53c700d0..628946824b 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -393,21 +393,30 @@ def apply(
             cu_seqlens_k,
             max_seqlen_k,
         ) = _convert_input_format(inp)
-        out, softmax_lse, rng_state = cls.OPERATOR(
-            inp.query,
-            inp.key,
-            inp.value,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            inp.p,
-            inp.scale_float,
-            _is_causal(inp.attn_bias),
-            _window_size(inp.attn_bias),
-            return_softmax,
-        )
-        out = out.reshape(out_shape)
+        if inp.query.numel() > 0 and inp.key.numel() > 0:
+            out, softmax_lse, rng_state = cls.OPERATOR(
+                inp.query,
+                inp.key,
+                inp.value,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                inp.p,
+                inp.scale_float,
+                _is_causal(inp.attn_bias),
+                _window_size(inp.attn_bias),
+                return_softmax,
+            )
+            out = out.reshape(out_shape)
+        else:
+            out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype)
+            rng_state = None
+            softmax_lse = torch.empty(
+                [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
+                device=inp.query.device,
+                dtype=torch.float32,
+            )
         ctx = Context(out=out, lse=softmax_lse)
         if inp.p != 0.0:
             ctx.op_bw = BwOp
@@ -544,26 +553,33 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
 
         )
         assert grad.dtype in cls.SUPPORTED_DTYPES
-        cls.OPERATOR(
-            grad.reshape(kernel_out_shape).contiguous(),
-            inp.query,
-            inp.key,
-            inp.value,
-            ctx.out.reshape(kernel_out_shape),
-            ctx_lse,
-            grads.dq,
-            grads.dk,
-            grads.dv,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            inp.p,
-            inp.scale_float,
-            _is_causal(inp.attn_bias),
-            _window_size(inp.attn_bias),
-            ctx.rng_state,
-        )
+
+        if grads.dq.numel() == 0:
+            grads.dk.zero_()
+            grads.dv.zero_()
+        if grads.dv.numel() == 0:
+            grads.dq.zero_()
+        if grads.dq.numel() and grads.dk.numel():
+            cls.OPERATOR(
+                grad.reshape(kernel_out_shape).contiguous(),
+                inp.query,
+                inp.key,
+                inp.value,
+                ctx.out.reshape(kernel_out_shape),
+                ctx_lse,
+                grads.dq,
+                grads.dk,
+                grads.dv,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                inp.p,
+                inp.scale_float,
+                _is_causal(inp.attn_bias),
+                _window_size(inp.attn_bias),
+                ctx.rng_state,
+            )
         grads.dq = grads.dq.reshape(dq_shape)
         grads.dk = grads.dk.reshape(dk_shape)
         grads.dv = grads.dv.reshape(dv_shape)
diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index dd7d2956ce..1c4f6d9421 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -551,7 +551,7 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
     @classmethod
     def get_split_k(cls, B: int, H: int, Mk: int) -> int:
         """Heuristic for the number of splits"""
-        bh = B * H
+        bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
         split_k = max(Mk, 1024) // bh
         max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
         while split_k > 0 and Mk / split_k < max_chunk_size:
             split_k = split_k // 2
@@ -694,6 +694,8 @@ def apply(
         assert G == 1
         out = out[:, :, 0]
         lse = lse[:, 0]
+        if Mk == 0:
+            out.zero_()
 
         return out, Context(out=out, lse=lse)
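In the get_split_k heuristic above, the max(B * H, 1) clamp only matters when the batch or head dimension is zero; without it the very first division raises ZeroDivisionError before the empty case can be handled. A numeric check of just that guarded expression (values picked arbitrarily):

    B, H, Mk = 0, 8, 4096          # empty batch
    bh = max(B * H, 1)             # clamps 0 -> 1
    split_k = max(Mk, 1024) // bh  # 4096; with the old `bh = B * H` this divides by zero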