[BUG][fMHA] Support empty tensors (fairinternal/xformers#817)
Co-authored-by: danthe3rd <danthe3rd>

__original_commit__ = fairinternal/xformers@ca7c6c5
danthe3rd authored and xFormers Bot committed Oct 9, 2023
1 parent fcc3a25 commit d18d0ef
Showing 8 changed files with 153 additions and 46 deletions.
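
For context, a minimal sketch of the kind of call this commit makes work, assuming a CUDA build of xformers; the shapes and dtypes below are illustrative only, not taken from the commit:

import torch
import xformers.ops

# Hypothetical BMHK shapes: batch=2, query_len=0, heads=4, head_dim=32; kv_len=128.
query = torch.randn(2, 0, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)
key = torch.randn(2, 128, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)
value = torch.randn(2, 128, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)

out = xformers.ops.memory_efficient_attention(query, key, value)  # shape (2, 0, 4, 32)
out.backward(out)
# With an empty query, dK and dV are expected to come back as all zeros.
assert torch.all(key.grad == 0) and torch.all(value.grad == 0)
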
62 changes: 62 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -2236,4 +2236,66 @@ def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K):
)


@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_query(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]

query = query[:, :0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
assert out.shape[1] == 0
out.backward(out)
# dK/dV should be all zeros
assert_allclose(key.grad, torch.zeros_like(key.grad), "key.grad")
assert_allclose(value.grad, torch.zeros_like(value.grad), "value.grad")


@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_kv(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]

key = key[:, :0]
value = value[:, :0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
assert_allclose(out, torch.zeros_like(out), "out")
out.backward(out)
# dQ should be all zeros
assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad")


@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_b(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]

query, key, value = query[:0], key[:0], value[:0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
out.backward(out)


# end of file
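
The new cases can be exercised in isolation, e.g. (assuming a CUDA build of xformers and a GPU are available):

pytest tests/test_mem_eff_attention.py -k empty_tensors
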
@@ -353,6 +353,15 @@ mem_efficient_attention_backward_cutlass(
workspace.zero_();
}
}

// Handle the edge-cases where some tensors are empty
if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 ||
p.num_heads == 0) {
grad_k.zero_();
grad_v.zero_();
grad_q.zero_();
return;
}
Kernel::check_supported(p);

if (smem_bytes > 0xc000) {
@@ -279,8 +279,13 @@ efficient_attention_forward_cutlass(
" kb)");
AT_CUDA_CHECK(err);
}
auto blocks = p.getBlocksGrid();
if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) {
res.zero_();
return;
}
Kernel::check_supported(p);
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
kernel_fn<<<blocks, p.getThreadsGrid(), smem_bytes, stream>>>(p);
};

// Dispatch to the right kernel
8 changes: 6 additions & 2 deletions xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -1179,8 +1179,12 @@ struct AttentionBackwardKernel {
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned");
XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned");
XFORMERS_CHECK(
p.num_heads <= 1 || p.lse_strideH % 8 == 0,
"LSE is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.num_batches <= 1 || p.lse_strideB % 8 == 0,
"LSE is not correctly aligned (strideB)");
XFORMERS_CHECK(
p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
"query is not correctly aligned (strideH)");
17 changes: 13 additions & 4 deletions xformers/csrc/attention/cuda/fmha/small_k.cu
@@ -652,6 +652,11 @@ void launch_attention(
dim3 grid(ceil_div(M, int64_t(TILE_SIZE)), B);
dim3 block(WARP_SIZE, TILE_SIZE / kBlockSizeQ);

if (grid.x * grid.y * grid.z == 0 || key.numel() == 0) {
res.zero_();
return;
}

using scalar_t = float;

auto attn_bias_packed = _packed_tensor_accessor_or_dummy<scalar_t>(attn_bias);
@@ -1100,6 +1105,9 @@ void launch_attention_backward(
dim3 grid(
ceil_div(M, int64_t(TILE_SIZEQ)), ceil_div(N, int64_t(TILE_SIZEK)), B);
dim3 block(TILE_SIZEQ / kBlockSizeQ, TILE_SIZEK / kBlockSizeK);
if (grid.x * grid.y * grid.z == 0) {
return;
}

// the bounds checking in device code is very expensive, making the code
// around 25% slower. So let's skip those checks if possible.
@@ -1444,10 +1452,11 @@ at::Tensor _dropout_mask(at::Tensor output, double p) {
// invert from drop probability to keep probability
p = 1.0 - p;

dropout_kernel<scalar_t, scalar_t, kBlockSizeK, kBlockSizeQ, WARP_SIZE>
<<<grid, block, 0, stream>>>(
output.packed_accessor<scalar_t, 3>(), p, rng_engine_inputs);

if (grid.x * grid.y * grid.z > 0) {
dropout_kernel<scalar_t, scalar_t, kBlockSizeK, kBlockSizeQ, WARP_SIZE>
<<<grid, block, 0, stream>>>(
output.packed_accessor<scalar_t, 3>(), p, rng_engine_inputs);
}
return output;
}

6 changes: 3 additions & 3 deletions xformers/ops/fmha/common.py
@@ -515,9 +515,9 @@ def from_arguments(
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
if tensor.ndim == 4:
return tensor
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
(0, 2, 1, 3)
)
return tensor.reshape(
[tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]]
).permute((0, 2, 1, 3))


def check_lastdim_alignment_stride1(
86 changes: 51 additions & 35 deletions xformers/ops/fmha/flash.py
@@ -393,21 +393,30 @@ def apply(
cu_seqlens_k,
max_seqlen_k,
) = _convert_input_format(inp)
out, softmax_lse, rng_state = cls.OPERATOR(
inp.query,
inp.key,
inp.value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
return_softmax,
)
out = out.reshape(out_shape)
if inp.query.numel() > 0 and inp.key.numel() > 0:
out, softmax_lse, rng_state = cls.OPERATOR(
inp.query,
inp.key,
inp.value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
return_softmax,
)
out = out.reshape(out_shape)
else:
out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype)
rng_state = None
softmax_lse = torch.empty(
[inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]],
device=inp.query.device,
dtype=torch.float32,
)
ctx = Context(out=out, lse=softmax_lse)
if inp.p != 0.0:
ctx.op_bw = BwOp
@@ -544,26 +553,33 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
)

assert grad.dtype in cls.SUPPORTED_DTYPES
cls.OPERATOR(
grad.reshape(kernel_out_shape).contiguous(),
inp.query,
inp.key,
inp.value,
ctx.out.reshape(kernel_out_shape),
ctx_lse,
grads.dq,
grads.dk,
grads.dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
ctx.rng_state,
)

if grads.dq.numel() == 0:
grads.dk.zero_()
grads.dv.zero_()
if grads.dv.numel() == 0:
grads.dq.zero_()
if grads.dq.numel() and grads.dk.numel():
cls.OPERATOR(
grad.reshape(kernel_out_shape).contiguous(),
inp.query,
inp.key,
inp.value,
ctx.out.reshape(kernel_out_shape),
ctx_lse,
grads.dq,
grads.dk,
grads.dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
ctx.rng_state,
)
grads.dq = grads.dq.reshape(dq_shape)
grads.dk = grads.dk.reshape(dk_shape)
grads.dv = grads.dv.reshape(dv_shape)
4 changes: 3 additions & 1 deletion xformers/ops/fmha/triton_splitk.py
@@ -551,7 +551,7 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
@classmethod
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
bh = B * H
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
split_k = max(Mk, 1024) // bh
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
while split_k > 0 and Mk / split_k < max_chunk_size:
@@ -694,6 +694,8 @@ def apply(
assert G == 1
out = out[:, :, 0]
lse = lse[:, 0]
if Mk == 0:
out.zero_()

return out, Context(out=out, lse=lse)

