From 86e9188702504340680d9e7bd46aebeb974b8c24 Mon Sep 17 00:00:00 2001
From: Longzhi Wang <583087864@qq.com>
Date: Mon, 8 Apr 2024 11:14:51 +0800
Subject: [PATCH] [Cherry-pick] Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. (#36)

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
---
 csrc/flash_attn/src/flash_bwd_kernel.h |  8 +++++---
 csrc/flash_attn/src/flash_fwd_kernel.h | 11 ++++++-----
 csrc/flash_attn/src/softmax.h          |  6 +++---
 3 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index a0399677296e5d..b9d41af015ca37 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -687,7 +687,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
 
     int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
+    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
+    m_block_min = m_block_min < 0 ? 0 : m_block_min;
 
     // We might need to exit early and write 0 to dK and dV.
     // Otherwise we get wrong result for the case where we don't enter the for loop.
@@ -873,7 +874,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         } else if (m_block * kBlockM < (n_block + 1) * kBlockN || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
             flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                     binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                     binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+                                     m_block * kBlockM + get<0>(taccScS_row(0)),
                                      // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                      AtomLayoutMS * 16);
         }
 
@@ -1424,7 +1426,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
         // the corresponding values of K would be 0, so the result would still be correct.
         if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
             flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                     binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                     binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                      // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                      AtomLayoutMS * 16);
         }
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index ac58d0af1ff283..0bdfeab79815c1 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -155,7 +155,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+        n_block_max = std::min(n_block_max, cute::ceil_div(
+            (m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
@@ -429,10 +430,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                          // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
                 gSparseMask.data() = gSparseMask.data() + (-kBlockN);
             } else {
-                flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                         // m_block * kBlockM + get<0>(idx_row(0)),
-                                         m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                         kNWarps * 16);
+                flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+                                         // m_block * kBlockM + get<0>(idx_row(0)),
+                                         m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                         kNWarps * 16);
                                          // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
                                          // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
             }
diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 148f2fce5b572a..00fc151d36dc13 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -142,8 +142,8 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
 
 template <typename Engine, typename Layout>
 inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
-                                         const uint32_t warp_row_stride) {
+                                         const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
+                                         const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const uint32_t lane_id = threadIdx.x % 32;
@@ -156,7 +156,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
     #pragma unroll
     for (int i = 0; i < size<0, 0>(tensor); ++i) {
         const uint32_t row_idx = row_idx_base + i * 8;
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+        const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
         #pragma unroll
         for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
            const uint32_t col_idx_base = col_idx_offset + nj * 8;
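Illustration (not part of the patch): a minimal host-side C++ sketch of the masking rule the patched apply_mask_causal implements. With causal masking and seqlen_k longer than seqlen_q, the causal diagonal is aligned to the bottom-right of the score matrix, so query row row_idx keeps key columns strictly below min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q), matching the new col_idx_limit above. The helper naive_causal_mask and its main driver are assumptions made up for demonstration only; they do not exist in the repository.

// naive_causal_mask.cpp -- illustrative sketch only, not part of the patch.
// Builds the same bottom-right-aligned causal mask that the patched
// apply_mask_causal applies per tile inside the kernels. Assumes
// seqlen_k >= seqlen_q, as in the case this patch targets.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns a seqlen_q x seqlen_k matrix of 0/1 flags: 1 = kept, 0 = masked out.
std::vector<std::vector<int>> naive_causal_mask(uint32_t seqlen_q, uint32_t seqlen_k) {
    std::vector<std::vector<int>> keep(seqlen_q, std::vector<int>(seqlen_k, 0));
    for (uint32_t row_idx = 0; row_idx < seqlen_q; ++row_idx) {
        // Same formula as the patched apply_mask_causal: shift the causal
        // boundary right by (seqlen_k - seqlen_q) so the last query row can
        // attend to every key.
        const uint32_t col_idx_limit =
            std::min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q);
        for (uint32_t col_idx = 0; col_idx < col_idx_limit; ++col_idx) {
            keep[row_idx][col_idx] = 1;
        }
    }
    return keep;
}

int main() {
    // Example: 3 queries attending over 5 keys (seqlen_k > seqlen_q).
    const auto keep = naive_causal_mask(3, 5);
    for (const auto &row : keep) {
        for (int v : row) { std::printf("%d ", v); }
        std::printf("\n");
    }
    // Prints:
    // 1 1 1 0 0
    // 1 1 1 1 0
    // 1 1 1 1 1
    return 0;
}

When seqlen_q equals seqlen_k the extra (seqlen_k - seqlen_q) term vanishes and the rule reduces to the usual lower-triangular mask, so the existing equal-length path is unaffected.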