diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index a5ce9056a2..9edacd0672 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -2236,4 +2236,66 @@ def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K):
)


+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_query(
+ opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+ query, key, value, attn_bias = create_tensors(
+ *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+ fmt="BMHK",
+ )
+ opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+ query = query[:, :0]
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
+ out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+ assert out.shape[1] == 0
+ out.backward(out)
+ # dK/dV should be all zeros
+ assert_allclose(key.grad, torch.zeros_like(key.grad), "key.grad")
+ assert_allclose(value.grad, torch.zeros_like(value.grad), "value.grad")
+
+
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_kv(
+ opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+ query, key, value, attn_bias = create_tensors(
+ *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+ fmt="BMHK",
+ )
+ opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+ key = key[:, :0]
+ value = value[:, :0]
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
+ out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+ assert_allclose(out, torch.zeros_like(out), "out")
+ out.backward(out)
+ # dQ should be all zeros
+ assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad")
+
+
+@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
+def test_empty_tensors_empty_b(
+ opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+):
+ query, key, value, attn_bias = create_tensors(
+ *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
+ fmt="BMHK",
+ )
+ opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
+
+ query, key, value = query[:0], key[:0], value[:0]
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
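+ # An empty batch (B=0) should go through forward and backward without errors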
+ out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
+ out.backward(out)
+
+
# end of file
diff --git a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
index 6efc0e9465..28ff346a0f 100644
--- a/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
+++ b/xformers/csrc/attention/cuda/fmha/attention_backward_generic.cu
@@ -353,6 +353,15 @@ mem_efficient_attention_backward_cutlass(
workspace.zero_();
}
}
+
+ // Handle the edge-cases where some tensors are empty
+ if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 ||
+ p.num_heads == 0) {
+ grad_k.zero_();
+ grad_v.zero_();
+ grad_q.zero_();
+ return;
+ }
Kernel::check_supported(p);
if (smem_bytes > 0xc000) {
diff --git a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
index 365ac999fc..5a08fb591c 100644
--- a/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
+++ b/xformers/csrc/attention/cuda/fmha/attention_forward_generic.cu
@@ -279,8 +279,13 @@ efficient_attention_forward_cutlass(
" kb)");
AT_CUDA_CHECK(err);
}
+ auto blocks = p.getBlocksGrid();
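+ // Nothing to compute when the launch grid or the key sequence is empty: the output is all zeros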
+ if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) {
+ res.zero_();
+ return;
+ }
Kernel::check_supported(p);
- kernel_fn<<<
- p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
+ kernel_fn<<<blocks, p.getThreadsGrid(), smem_bytes, stream>>>(p);
};
// Dispatch to the right kernel
diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
index 4fb2ae8d46..3276e11612 100644
--- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h
+++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -1179,8 +1179,12 @@ struct AttentionBackwardKernel {
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
- XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned");
- XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned");
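+ // The LSE stride alignment only matters when the corresponding dimension has more than one element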
+ XFORMERS_CHECK(
+ p.num_heads <= 1 || p.lse_strideH % 8 == 0,
+ "LSE is not correctly aligned (strideH)");
+ XFORMERS_CHECK(
+ p.num_batches <= 1 || p.lse_strideB % 8 == 0,
+ "LSE is not correctly aligned (strideB)");
XFORMERS_CHECK(
p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
"query is not correctly aligned (strideH)");
diff --git a/xformers/csrc/attention/cuda/fmha/small_k.cu b/xformers/csrc/attention/cuda/fmha/small_k.cu
index 09ced11b99..52632f64a1 100644
--- a/xformers/csrc/attention/cuda/fmha/small_k.cu
+++ b/xformers/csrc/attention/cuda/fmha/small_k.cu
@@ -652,6 +652,11 @@ void launch_attention(
dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B);
dim3 block(WARP_SIZE, TILE_SIZE / kBlockSizeQ);
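+ // Nothing to compute for an empty grid or empty keys: the result is all zeros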
+ if (grid.x * grid.y * grid.z == 0 || key.numel() == 0) {
+ res.zero_();
+ return;
+ }
+
using scalar_t = float;
auto attn_bias_packed = _packed_tensor_accessor_or_dummy(attn_bias);
@@ -1100,6 +1105,9 @@ void launch_attention_backward(
dim3 grid(
ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B);
dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK);
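+ // Launching a kernel with an empty grid is invalid in CUDA, and there is nothing to compute anyway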
+ if (grid.x * grid.y * grid.z == 0) {
+ return;
+ }
// the bounds checking in device code is very expensive, making the code
// around 25% slower. So let's skip those checks if possible.
@@ -1444,10 +1452,11 @@ at::Tensor _dropout_mask(at::Tensor output, double p) {
// invert from drop probability to keep probability
p = 1.0 - p;
- dropout_kernel
- <<<grid, block, 0, stream>>>(
- output.packed_accessor(), p, rng_engine_inputs);
-
+ if (grid.x * grid.y * grid.z > 0) {
+ dropout_kernel
+ <<<grid, block, 0, stream>>>(
+ output.packed_accessor(), p, rng_engine_inputs);
+ }
return output;
}
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index b2d013b099..4fbf26c6aa 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -515,9 +515,9 @@ def from_arguments(
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
if tensor.ndim == 4:
return tensor
- return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
- (0, 2, 1, 3)
- )
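+ # Compute the batch dimension explicitly: reshape with -1 is ambiguous for empty tensors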
+ return tensor.reshape(
+ [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
+ ).permute((0, 2, 1, 3))
def check_lastdim_alignment_stride1(
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index ed53c700d0..628946824b 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -393,21 +393,30 @@ def apply(
cu_seqlens_k,
max_seqlen_k,
) = _convert_input_format(inp)
- out, softmax_lse, rng_state = cls.OPERATOR(
- inp.query,
- inp.key,
- inp.value,
- cu_seqlens_q,
- cu_seqlens_k,
- max_seqlen_q,
- max_seqlen_k,
- inp.p,
- inp.scale_float,
- _is_causal(inp.attn_bias),
- _window_size(inp.attn_bias),
- return_softmax,
- )
- out = out.reshape(out_shape)
+ if inp.query.numel() > 0 and inp.key.numel() > 0:
+ out, softmax_lse, rng_state = cls.OPERATOR(
+ inp.query,
+ inp.key,
+ inp.value,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ inp.p,
+ inp.scale_float,
+ _is_causal(inp.attn_bias),
+ _window_size(inp.attn_bias),
+ return_softmax,
+ )
+ out = out.reshape(out_shape)
+ else:
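+ # Empty query or key: the output is all zeros and the LSE is only a placeholder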
+ out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype)
+ rng_state = None
+ softmax_lse = torch.empty(
+ [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
+ device=inp.query.device,
+ dtype=torch.float32,
+ )
ctx = Context(out=out, lse=softmax_lse)
if inp.p != 0.0:
ctx.op_bw = BwOp
@@ -544,26 +553,33 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
)
assert grad.dtype in cls.SUPPORTED_DTYPES
- cls.OPERATOR(
- grad.reshape(kernel_out_shape).contiguous(),
- inp.query,
- inp.key,
- inp.value,
- ctx.out.reshape(kernel_out_shape),
- ctx_lse,
- grads.dq,
- grads.dk,
- grads.dv,
- cu_seqlens_q,
- cu_seqlens_k,
- max_seqlen_q,
- max_seqlen_k,
- inp.p,
- inp.scale_float,
- _is_causal(inp.attn_bias),
- _window_size(inp.attn_bias),
- ctx.rng_state,
- )
+
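+ # Handle empty inputs without calling the kernel: empty queries give zero dK/dV, empty keys/values give zero dQ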
+ if grads.dq.numel() == 0:
+ grads.dk.zero_()
+ grads.dv.zero_()
+ if grads.dv.numel() == 0:
+ grads.dq.zero_()
+ if grads.dq.numel() and grads.dk.numel():
+ cls.OPERATOR(
+ grad.reshape(kernel_out_shape).contiguous(),
+ inp.query,
+ inp.key,
+ inp.value,
+ ctx.out.reshape(kernel_out_shape),
+ ctx_lse,
+ grads.dq,
+ grads.dk,
+ grads.dv,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ inp.p,
+ inp.scale_float,
+ _is_causal(inp.attn_bias),
+ _window_size(inp.attn_bias),
+ ctx.rng_state,
+ )
grads.dq = grads.dq.reshape(dq_shape)
grads.dk = grads.dk.reshape(dk_shape)
grads.dv = grads.dv.reshape(dv_shape)
diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py
index dd7d2956ce..1c4f6d9421 100644
--- a/xformers/ops/fmha/triton_splitk.py
+++ b/xformers/ops/fmha/triton_splitk.py
@@ -551,7 +551,7 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
@classmethod
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
- bh = B * H
+ bh = max(B * H, 1)  # NOTE: Handle B*H=0 case
split_k = max(Mk, 1024) // bh
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
while split_k > 0 and Mk / split_k < max_chunk_size:
@@ -694,6 +694,8 @@ def apply(
assert G == 1
out = out[:, :, 0]
lse = lse[:, 0]
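+ # No keys to attend to: define the attention output as zeros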
+ if Mk == 0:
+ out.zero_()
return out, Context(out=out, lse=lse)