diff --git a/csrc/flash_attn/CMakeLists.txt b/csrc/flash_attn/CMakeLists.txt index 49f581440061d2..7267d9232231ec 100644 --- a/csrc/flash_attn/CMakeLists.txt +++ b/csrc/flash_attn/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(flashattn SHARED ${SOURCES_CU} ${SOURCES_CPP} flash_attn.cpp + flash_attn_with_bias_mask.cpp ) target_compile_options(flashattn PRIVATE $<$: diff --git a/csrc/flash_attn/flash_attn.cpp b/csrc/flash_attn/flash_attn.cpp index bc4dd26e1dd567..42f2644b41f494 100644 --- a/csrc/flash_attn/flash_attn.cpp +++ b/csrc/flash_attn/flash_attn.cpp @@ -26,6 +26,7 @@ * ******************************************************************************/ +#include "flash_attn.h" #include "fmha.h" #include "utils.h" #include "cuda_utils.h" @@ -62,7 +63,7 @@ extern "C" { static thread_local std::unique_ptr flash_attn_err_msg; -static void flash_attn_set_error(const char *msg) { +void flash_attn_set_error(const char *msg) { if (msg == nullptr || *msg == '\0') { msg = "unknown error"; } diff --git a/csrc/flash_attn/flash_attn.h b/csrc/flash_attn/flash_attn.h index 16ffd5cce47a3b..c1febc145093d5 100644 --- a/csrc/flash_attn/flash_attn.h +++ b/csrc/flash_attn/flash_attn.h @@ -70,6 +70,79 @@ bool flash_attn_bwd( uint64_t offset ); +bool flash_attn_fwd_with_bias_and_mask( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const int32_t *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const int32_t *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void *softmax_lse_ptr, // softmax log_sum_exp + void *softmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void *attn_mask, + const void *attn_bias, + const int64_t* mask_dims, + const int64_t* bias_dims +); + +bool flash_attn_bwd_with_bias_and_mask( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *dout, // total_q x num_heads, x head_size + const int32_t *cu_seqlens_q, // int32, batch_size+1 + const int32_t *cu_seqlens_k, // int32, batch_size+1 + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool is_bf16, + const int num_splits, + const void *softmax_lse_ptr, + void *dsoftmax_ptr, + void *dbias_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void* attn_mask, + const void* attn_bias, + const int64_t* mask_dims, + const int64_t* bias_dims +); + +void flash_attn_set_error(const char *msg); + const char *flash_attn_error(); #ifdef __cplusplus diff --git a/csrc/flash_attn/flash_attn_with_bias_mask.cpp b/csrc/flash_attn/flash_attn_with_bias_mask.cpp new file mode 100644 index 00000000000000..c4980ad769e9ce --- /dev/null +++ b/csrc/flash_attn/flash_attn_with_bias_mask.cpp @@ -0,0 +1,539 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include "flash_attn.h" +#include "fmha.h" +#include "utils.h" +#include "cuda_utils.h" +#include +#include + +#include "cuda.h" +#include "cuda_runtime.h" +#include "dlfcn.h" +#include "math.h" +#include +#include +#include + +#include +#include +#include + +#define FLASH_ATTN_ASSERT_CHECK(__cond) \ + do { \ + const bool __cond_var = (__cond); \ + if (!__cond_var) { \ + ::std::string __err_msg = ::std::string("`") + \ + #__cond + "` check failed at " + \ + __FILE__ + ":" + \ + ::std::to_string(__LINE__); \ + throw std::runtime_error(__err_msg); \ + } \ + } while (0) + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +} +#endif + +#define FLASHATTNLIB_BEGIN_FUNC try { +#define FLASHATTNLIB_END_FUNC } catch (::std::exception &__e) { flash_attn_set_error(__e.what()); return false; } catch (...) { flash_attn_set_error(nullptr); return false; } + +void set_params_fprop_with_bias_mask(FMHA_fprop_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t h, + const size_t d, + // device pointers + void *q, + void *k, + void *v, + void *out, + int32_t *cu_seqlens_q_d, + int32_t *cu_seqlens_k_d, + void *o_tmp_d, + void *s_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits, + void *attn_mask = nullptr, + void *attn_bias = nullptr, + int bias_mod_size = 0, + int mask_head_mod_size = 0, + int mask_seq_mod_size = 0) { + Data_type data_type = is_bf16 ? DATA_TYPE_BF16 : DATA_TYPE_FP16; + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = is_bf16; + + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.q_row_stride_in_elts = h * d; + params.k_row_stride_in_elts = h * d; + params.v_row_stride_in_elts = h * d; + params.q_head_stride_in_elts = d; + params.k_head_stride_in_elts = d; + params.v_head_stride_in_elts = d; + params.o_ptr = out; + params.o_row_stride_in_elts = h * d; + params.o_head_stride_in_elts = d; + params.o_tmp_ptr = o_tmp_d; + params.o_tmp_row_stride_in_elts = h * d; + params.o_tmp_head_stride_in_elts = d; + + params.cu_seqlens_q = cu_seqlens_q_d; + params.cu_seqlens_k = cu_seqlens_k_d; + + // S = softmax(P) + params.s_ptr = s_d; + params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = d; + + // attn mask & bias + params.attn_mask_ptr = attn_mask; + params.attn_bias_ptr = attn_bias; + params.bias_mod_size = bias_mod_size; + params.mask_head_mod_size = mask_head_mod_size; + params.mask_seq_mod_size = mask_seq_mod_size; + + // Set the different scale values. + // const float scale_bmm1 = 1.f / sqrtf(d); + const float scale_bmm1 = softmax_scale; + + params.scale_bmm1f = scale_bmm1; + set_alpha(params.scale_bmm1, scale_bmm1, data_type); + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f; + FLASH_ATTN_ASSERT_CHECK(p_dropout < 1.f); + set_alpha(params.scale_dropout, params.rp_dropout, data_type); + + params.is_causal = is_causal; + params.num_splits = num_splits; +} + +void set_params_dgrad_with_bias_mask(FMHA_dgrad_params ¶ms, + const size_t b, // sizes + const size_t seqlen_q, + const size_t seqlen_k, + const size_t h, + const size_t d, + void *q, // device pointers + void *k, + void *v, + void *out, + void *dq, + void *dk, + void *dv, + int32_t *cu_seqlens_q_d, + int32_t *cu_seqlens_k_d, + void *dq_tmp_d, + void *do_packed_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits, + void *attn_mask = nullptr, + void *attn_bias = nullptr, + void *attn_ds = nullptr, + int bias_mod_size = 0, + int mask_head_mod_size = 0, + int mask_seq_mod_size = 0) { + set_params_fprop_with_bias_mask(params, + b, + seqlen_q, + seqlen_k, + h, + d, + q, + k, + v, + out, + cu_seqlens_q_d, + cu_seqlens_k_d, + dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp + nullptr, + softmax_lse_d, + p_dropout, + softmax_scale, + is_causal, + is_bf16, + num_splits, + attn_mask, + attn_bias, + bias_mod_size, + mask_head_mod_size, + mask_seq_mod_size); + + // Set the pointers and strides. + params.dq_ptr = dq; + params.dk_ptr = dk; + params.dv_ptr = dv; + params.dq_row_stride_in_elts = h * d; + params.dk_row_stride_in_elts = h * d; + params.dv_row_stride_in_elts = h * d; + params.dq_head_stride_in_elts = d; + params.dk_head_stride_in_elts = d; + params.dv_head_stride_in_elts = d; + params.do_ptr = do_packed_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + params.attn_ds_ptr = attn_ds; +} + +void run_fwd_with_bias_mask(Launch_params &launch_params, + const bool configure) { + if (launch_params.params.d == 16) { + run_fmha_fwd_with_mask_bias_hdim16(launch_params, configure); + } else if (launch_params.params.d == 32) { + run_fmha_fwd_with_mask_bias_hdim32(launch_params, configure); + } else if (launch_params.params.d == 64) { + run_fmha_fwd_with_mask_bias_hdim64(launch_params, configure); + } else if (launch_params.params.d == 128) { + run_fmha_fwd_with_mask_bias_hdim128(launch_params, configure); + } +} + +void run_bwd_with_bias_mask(FMHA_dgrad_params ¶ms, + cudaStream_t stream) { + if (params.d == 16) { + run_fmha_bwd_with_mask_bias_hdim16(params, stream); + } else if (params.d == 32) { + run_fmha_bwd_with_mask_bias_hdim32(params, stream); + } else if (params.d == 64) { + run_fmha_bwd_with_mask_bias_hdim64(params, stream); + } else if (params.d == 128) { + run_fmha_bwd_with_mask_bias_hdim128(params, stream); + } +} + +#ifdef __cplusplus +extern "C" { +#endif + + +// For just alphafold2 +bool flash_attn_fwd_with_bias_and_mask( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const int32_t *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const int32_t *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void *softmax_lse_ptr, // softmax log_sum_exp + void *softmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void *attn_mask = nullptr, + const void *attn_bias = nullptr, + const int64_t* mask_dims = nullptr, + const int64_t* bias_dims = nullptr) { + // printf("forward seed %jd offset %jd\b", seed, offset); + FLASHATTNLIB_BEGIN_FUNC + + auto dprops = GetDeviceProperties(-1); + bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + + FLASH_ATTN_ASSERT_CHECK(is_sm8x || is_sm75); + FLASH_ATTN_ASSERT_CHECK(batch_size > 0); + FLASH_ATTN_ASSERT_CHECK((head_size % 8 == 0) && (head_size <= 128)); + + int blocksize_c = head_size > 64 ? 128 : 256; + // Need to round max_seqlen_k to multiples of blocksize_c + int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; + if( max_seqlen_k_ <= 128 ) { + max_seqlen_k = 128; + } else if( max_seqlen_k_ <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > blocksize_c; + + void* o_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + if (loop) { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } else { + *workspace_size = 0; + } + return true; + } + int bias_mod_size = attn_bias ? bias_dims[0] : 0; + if (attn_bias) { + FLASH_ATTN_ASSERT_CHECK(bias_dims[1] == num_heads); + } + int mask_head_mod_size = attn_mask ? mask_dims[1] : 0; + int mask_seq_mod_size = attn_mask ? mask_dims[2] : 0; + if (attn_mask) { + FLASH_ATTN_ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads); + FLASH_ATTN_ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == max_seqlen_q_); + } + + bool return_softmax = (softmax_ptr != nullptr); + bool is_dropout = p_dropout > 0.f; + Launch_params launch_params(dprops, stream, is_dropout, return_softmax); + + if (zero_tensors) { + SetZero(out, 2, {total_q, num_heads, head_size}, stream); + SetConstValue(softmax_lse_ptr, -std::numeric_limits::infinity(), uint64_t(batch_size) * num_heads * max_seqlen_q, stream); + if (return_softmax) SetZero(softmax_ptr, 2, {batch_size, num_heads, max_seqlen_q, max_seqlen_k}, stream); // float16 + } + + set_params_fprop_with_bias_mask(launch_params.params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + const_cast(cu_seqlens_q), + const_cast(cu_seqlens_k), + loop ? o_tmp_ptr : nullptr, + return_softmax ? softmax_ptr : nullptr, + softmax_lse_ptr, + p_dropout, + softmax_scale, + is_causal, + is_bf16, + num_splits, + const_cast(attn_mask), + const_cast(attn_bias), + bias_mod_size, + mask_head_mod_size, + mask_seq_mod_size); + run_fwd_with_bias_mask(launch_params, /*configure=*/ true); + + if( is_dropout ) { + launch_params.params.philox_args = PhiloxCudaState(seed, offset); + } + run_fwd_with_bias_mask(launch_params, /*configure=*/false); + return true; + FLASHATTNLIB_END_FUNC +} + + +bool flash_attn_bwd_with_bias_and_mask( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *dout, // total_q x num_heads, x head_size + const int32_t *cu_seqlens_q, // int32, batch_size+1 + const int32_t *cu_seqlens_k, // int32, batch_size+1 + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool is_bf16, + const int num_splits, + const void *softmax_lse_ptr, + void *dsoftmax_ptr, + void *dbias_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void* attn_mask = nullptr, + const void* attn_bias = nullptr, + const int64_t* mask_dims = nullptr, + const int64_t* bias_dims = nullptr) { + // printf("backward seed %jd offset %jd\b", seed, offset); + FLASHATTNLIB_BEGIN_FUNC + auto dprops = GetDeviceProperties(-1); + bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + FLASH_ATTN_ASSERT_CHECK(is_sm8x || is_sm75); + + bool is_dropout = p_dropout > 0.0; + + FLASH_ATTN_ASSERT_CHECK(batch_size > 0); + FLASH_ATTN_ASSERT_CHECK((head_size % 8 == 0) && (head_size <= 128)); + if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well + FLASH_ATTN_ASSERT_CHECK(is_sm80); + } + + int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256; + int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; + if( max_seqlen_k_ <= 128 ) { + max_seqlen_k = 128; + } else if( max_seqlen_k_ <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > blocksize_c; + + void *dq_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + // There are two cases no need to allocate workspace: + // 1) num_splits == 1 + // 2) num_splits == 0 for auto calculation, result to num_splits == 1 + // we do allocation for case 2 for simplicity + if (num_splits == 1) { + *workspace_size = 0; + } else { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } + return true; + } + + int bias_mod_size = 0; + if (attn_bias) { + // check attn_bias shape + bias_mod_size = bias_dims[0]; + SetZero(dbias_ptr, 2, {batch_size, num_heads, max_seqlen_q_, max_seqlen_k_}, stream); + FLASH_ATTN_ASSERT_CHECK(bias_dims[1] == num_heads); + } + + int mask_head_mod_size = 0; + int mask_seq_mod_size = 0; + if (attn_mask) { + // last two dimension + mask_head_mod_size = mask_dims[1]; + mask_seq_mod_size = mask_dims[2]; + FLASH_ATTN_ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads); + FLASH_ATTN_ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == max_seqlen_q_); + } + + if(zero_tensors) { + SetZero(dq, 2, {total_q, num_heads, head_size}, stream); + SetZero(dk, 2, {total_q, num_heads, head_size}, stream); + SetZero(dv, 2, {total_q, num_heads, head_size}, stream); + SetZero(dsoftmax_ptr, 4, {batch_size, num_heads, max_seqlen_q}, stream); + } + + FMHA_dgrad_params params; + set_params_dgrad_with_bias_mask(params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + dq, + dk, + dv, + const_cast(cu_seqlens_q), + const_cast(cu_seqlens_k), + loop ? dq_tmp_ptr : nullptr, + const_cast(dout), + const_cast(softmax_lse_ptr), + dsoftmax_ptr, + p_dropout, + softmax_scale, + is_causal, + is_bf16, + num_splits, + attn_mask ? const_cast(attn_mask) : nullptr, + attn_bias ? const_cast(attn_bias) : nullptr, + attn_bias ? dbias_ptr : nullptr, + bias_mod_size, + mask_head_mod_size, + mask_seq_mod_size); + + if(is_dropout) { + params.philox_args = PhiloxCudaState(seed, offset); + } + run_bwd_with_bias_mask(params, stream); + return true; + FLASHATTNLIB_END_FUNC +} + +#ifdef __cplusplus +} +#endif + diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 43b6f4c298c035..6e53bf0a246268 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -234,13 +234,12 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(cu_seqlens_q.is_contiguous()); TORCH_CHECK(cu_seqlens_k.is_contiguous()); - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - const int total_q = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - const int total_k = k.size(TOTAL_DIM); + const auto sizes = q.sizes(); // q : torch.Size([32768, 8, 32]) + const int batch_size = cu_seqlens_q.numel() - 1; // q_cu_seqlens : torch.Size([129]) + const int total_q = sizes[/*0*/TOTAL_DIM]; // 32768 + const int num_heads = sizes[/*1*/H_DIM]; // 8 + const int head_size = sizes[/*2*/D_DIM]; // 32 + const int total_k = k.size(TOTAL_DIM); // 32768 TORCH_CHECK(batch_size > 0); TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128)); diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 758e28413c1d5d..775b3908e5f9c7 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -64,6 +64,17 @@ struct Qkv_params { //////////////////////////////////////////////////////////////////////////////////////////////////// struct FMHA_fprop_params : public Qkv_params { + // The attn mask matrix + void * __restrict__ attn_mask_ptr; + int mask_head_mod_size; + int mask_seq_mod_size; + + // The attn bias matrix + void * __restrict__ attn_bias_ptr; + int bias_mod_size; + + // The ds matrix + void * __restrict__ attn_ds_ptr; // The O matrix (output). void * __restrict__ o_ptr; @@ -195,6 +206,16 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const b void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure); void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure); +void run_fmha_fwd_with_mask_bias_hdim16(Launch_params &launch_params, const bool configure); +void run_fmha_fwd_with_mask_bias_hdim32(Launch_params &launch_params, const bool configure); +void run_fmha_fwd_with_mask_bias_hdim64(Launch_params &launch_params, const bool configure); +void run_fmha_fwd_with_mask_bias_hdim128(Launch_params &launch_params, const bool configure); + +void run_fmha_bwd_with_mask_bias_hdim16(FMHA_dgrad_params ¶ms, cudaStream_t stream); +void run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream); +void run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream); +void run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream); + void run_fmha_block_fp16_sm80(Launch_params &launch_params, const bool configure); void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index a142f0bf2c62ad..2fff2b219f0aa6 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -165,6 +165,12 @@ struct Fragment_b : public Fragment { //////////////////////////////////////////////////////////////////////////////////////////////////// +template< typename Layout, typename elem_type > +struct Fragment_c : public Fragment { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + struct Fragment_accumulator : public Fragment { // The base class. @@ -184,6 +190,15 @@ struct Fragment_accumulator : public Fragment { } } + template< typename Other_fragment_ > + inline __device__ void addf(const Other_fragment_ &other) { + // elt or reg? + #pragma unroll + for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + this->elt(ii) = this->elt(ii) + toFloat(other.elt(ii)); + } + } + // Do the HMMA. template< typename Layout_a, typename Layout_b > inline __device__ void mma(const Fragment_a &a, diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index e0bd24c3c09b00..9d68c164b25097 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -65,7 +65,7 @@ struct Gmem_tile_qkv { static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG); // Ctor. - template< typename BInfo > + template inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, const int headdim, const BInfo &binfo, const int tidx, bool use_seqlen_q) @@ -95,6 +95,38 @@ struct Gmem_tile_qkv { ptr += row_offset + col * BYTES_PER_LDG; } + // Ctor. + template + inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, + const BInfo &binfo, const int tidx, bool use_seqlen_q) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k) + , ptr(reinterpret_cast(ptr_)) + , tidx_(tidx) + , col_predicate(true) { + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable the loads. + // TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it + // row_ = row; + + // The row offset in the batched GEMM. For each seq element, we store QKV in that order. + // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; + uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes); + // Add the block index. + + // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; + row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); + + // Assemble the final pointer. + ptr += row_offset + col * BYTES_PER_LDG; + } + // Store data to shared memory. template< typename Smem_tile > inline __device__ void commit(Smem_tile &smem_tile) { @@ -227,6 +259,38 @@ struct Gmem_tile_o { } } + // Ctor. + template + // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx) + inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , ptr_(reinterpret_cast(ptr)) + , tidx_(tidx) + , col_predicate(true) { + + // Compute the position in the sequence (within the CTA for the moment). + int row = tidx / THREADS_PER_ROW; + // Compute the position of the thread in the row. + int col = tidx % THREADS_PER_ROW; + + // Store the row as we need it to disable loads. + // row_ = row; + + // The row offset in the batched GEMM. + // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; + uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes); + row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); + // Assemble the final pointer. + ptr_ += row_offset + col * BYTES_PER_STG; + + // Is that thread active on the last STG? + if( HAS_INCOMPLETE_STG ) { + is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; + } + } + // Store data to global memory. template inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { @@ -435,6 +499,526 @@ struct Gmem_tile_mma_s : public Base { //////////////////////////////////////////////////////////////////////////////////////////////////// +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_mask { + + using Mma_tile = fmha::Hmma_tile; + // The type of the vectors stored by each STG. + using StoreType = uint32_t; + + // static constexpr int LDG_ELEMENTS = 2 + // using Type = typename fmha::Uint_from_size_in_bytes< LDG_ELEMENTS * BYTES_PER_ELEMENT >::Type; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. + // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_mask(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx, const int loop_step_idx) + : ptr_(static_cast(params.attn_mask_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + , loop_step_idx(loop_step_idx) + , mask_seq_mod_size(params.mask_seq_mod_size) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs * seq, head, seq_q, seq_k] + // The block index. + // uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + uint32_t bidx = binfo.bidb * params.mask_head_mod_size + (binfo.bidh % params.mask_head_mod_size); + + // the index of bs and head dim + // uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + + // to support the mask last two dimension + uint32_t row_offset = bidx * params.mask_seq_mod_size * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + row_offset += (uint32_t)( (row % params.mask_seq_mod_size) * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + + ptr_ += row_offset; + } + + // Load from global memory to Fragment. + template + inline __device__ void load(Fragment (&frag)[M][N]) { + // using Fragment = typename fmha::Fragment; + + const void *ptrs[LDGS_PER_THREAD_PER_WARP]; + uint32_t preds[LDGS_PER_THREAD_PER_WARP]; + + if (!(actual_seqlen_k & 1)) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + ptrs[offset] = ptr_ + (uint32_t)(current_row % mask_seq_mod_size) * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + preds[offset] = (current_row + (row % mask_seq_mod_size) < min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); + } + } + // load data + Ldg_functor fct(frag[mi][ni].regs_, ptrs); + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + fct.load(kk, preds[kk]); + } + } + } + }else{ + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + ptrs[offset] = ptr_ + (uint32_t)(current_row % mask_seq_mod_size) * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + preds[offset] = 0; + if ((current_row + (row % mask_seq_mod_size) < min(ROWS, actual_seqlen_q))) { + if(current_col <= actual_seqlen_k) { + if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k){ + preds[offset] = 1; + }else if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT - 1) == actual_seqlen_k) { + preds[offset] = 2; + } + } + } + } + } + // load data + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + if (preds[kk] == 1) { + uint16_t dst_16_h = *reinterpret_cast(ptrs[kk]); + uint16_t dst_16_l = *(reinterpret_cast(ptrs[kk]) + 1); + frag[mi][ni].regs_[kk] = ((uint32_t)dst_16_l << 16) + dst_16_h; + } + if (preds[kk] == 2) { + uint16_t dst_16 = *reinterpret_cast(ptrs[kk]); + frag[mi][ni].regs_[kk] = ((uint32_t)0 << 16) + dst_16; + } + } + } + } + } + } + + inline __device__ void move(const int steps = 1) { + // to support the mask last two dimension + ptr_ += (uint32_t)(ROWS % mask_seq_mod_size) * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + const int loop_step_idx; + uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + int mask_seq_mod_size; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_bias { + + using Mma_tile = fmha::Hmma_tile; + // The type of the vectors stored by each STG. + using StoreType = uint32_t; + + // static constexpr int LDG_ELEMENTS = 2 + // using Type = typename fmha::Uint_from_size_in_bytes< LDG_ELEMENTS * BYTES_PER_ELEMENT >::Type; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. + // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_bias(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx, const int loop_step_idx) + : ptr_(static_cast(params.attn_bias_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + , loop_step_idx(loop_step_idx) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs, head, seq_q, seq_k] + // The block index. + // uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + uint32_t bidx = ( binfo.bidb % params.bias_mod_size ) * params.h + binfo.bidh; + + // the index of bs and head dim + uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset = (uint32_t)(row * row_stride_in_bytes); + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + + // do we need to move col first if seklen_k > cols + ptr_ += row_offset; + } + + // Load from global memory to Fragment. + template + inline __device__ void load(Fragment (&frag)[M][N]) { + const void *ptrs[LDGS_PER_THREAD_PER_WARP]; + uint32_t preds[LDGS_PER_THREAD_PER_WARP]; + + if (!(actual_seqlen_k & 1)) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + + preds[offset] = (current_row + row < min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); + } + } + + Ldg_functor fct(frag[mi][ni].regs_, ptrs); + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + fct.load(kk, preds[kk]); + } + } + } + }else{ + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + preds[offset] = 0; + if ((current_row + row < min(ROWS, actual_seqlen_q))) { + if(current_col <= actual_seqlen_k) { + if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k){ + preds[offset] = 1; + }else if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT - 1) == actual_seqlen_k) { + preds[offset] = 2; + } + } + } + } + } + // load data + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + if (preds[kk] == 1) { + uint16_t dst_16_h = *reinterpret_cast(ptrs[kk]); + uint16_t dst_16_l = *(reinterpret_cast(ptrs[kk]) + 1); + frag[mi][ni].regs_[kk] = ((uint32_t)dst_16_l << 16) + dst_16_h; + } + if (preds[kk] == 2) { + uint16_t dst_16 = *reinterpret_cast(ptrs[kk]); + frag[mi][ni].regs_[kk] = ((uint32_t)0 << 16) + dst_16; + } + } + } + } + } + } + + inline __device__ void move(const int steps = 1) { + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + const int loop_step_idx; + uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_ds { + + using Mma_tile = fmha::Hmma_tile; + // The type of the vectors stored by each STG. + using StoreType = uint32_t; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. + // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_ds(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx, const int loop_step_idx) + : ptr_(static_cast(params.attn_ds_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + , loop_step_idx(loop_step_idx) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs, head, seq_q, seq_k] + // The block index. + uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + + // the index of bs and head dim + uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset = (uint32_t)(row * row_stride_in_bytes); + + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + // do we need to move col first if seklen_k > cols + ptr_ += row_offset; + } + + // Store to global memory. + template + inline __device__ void store(const float (&softmax)[2 * M][4 * N], int l=0) { + uint32_t preds; + uint32_t dst; + + if (!(actual_seqlen_k & 1)) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + float tmp00 = softmax[2 * mi + ii][4 * ni + jj * 2]; + float tmp01 = softmax[2 * mi + ii][4 * ni + jj * 2 + 1]; + dst = fmha::float2_pack(tmp00, tmp01); + + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + + char *ptrs = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + + preds = (current_row + row < min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); + if (preds) { + fmha::stg(ptrs, dst); + } + } + } + } + } + }else{ + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + float tmp00 = softmax[2 * mi + ii][4 * ni + jj * 2]; + float tmp01 = softmax[2 * mi + ii][4 * ni + jj * 2 + 1]; + uint16_t data1 = fmha::float_pack(tmp00); + uint16_t data2 = fmha::float_pack(tmp01); + + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + + char *ptrs = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + preds = 0; + if ((current_row + row < min(ROWS, actual_seqlen_q))) { + if(current_col <= actual_seqlen_k) { + if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k){ + preds = 1; + }else if((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT - 1) == actual_seqlen_k) { + preds = 2; + } + } + } + + if (preds == 1) { + fmha::stg(reinterpret_cast(ptrs), data1); + fmha::stg(reinterpret_cast(ptrs) + 1, data2); + }else if (preds == 2) { + fmha::stg(reinterpret_cast(ptrs), data1); + } + } + } + } + } + } + } + + inline __device__ void move(const int steps = 1) { + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + const int loop_step_idx; + uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// template< // The dimensions of the tile computed by the CTA. typename Cta_tile diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h index 63f07aee8e090a..867a9e74f3e271 100644 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ b/csrc/flash_attn/src/fmha/kernel_traits.h @@ -76,6 +76,15 @@ struct FMHA_kernel_traits { using Gmem_tile_do = fmha::Gmem_tile_qkv; + // Gmem_tile_mma_mask + using Gmem_tile_mask = fmha::Gmem_tile_mma_mask; + + // Gmem_tile_mma_bias + using Gmem_tile_bias = fmha::Gmem_tile_mma_bias; + + // Gmem_tile_mma_ds + using Gmem_tile_ds = fmha::Gmem_tile_mma_ds; + // // The global memory tile to store the accumulated dK and dV // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV // // where there are 16 bits per lements and 16 bytes per load. In reality we won't diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index bd874375e5d372..6dd84327afd1d1 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -457,6 +457,46 @@ struct Softmax : public Softmax_base { , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { } + template + inline __device__ void apply_attn_mask(const Fragment (&attn_mask)[MMAS_M][MMAS_N], const Mask &mask, int l = 0, int loop_step_idx = 0) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + if( mask.is_valid(mi, ni, ii, jj) ) { + float value = toFloat(attn_mask[mi][ni].elt(ii * 4 + jj)); + this->elt_[2 * mi + ii][4 * ni + jj] += value; + } + } + } + } + } + } + + template + inline __device__ void apply_attn_bias(const Fragment (&bias)[MMAS_M][MMAS_N], const Mask &mask, int l = 0) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + if( mask.is_valid(mi, ni, ii, jj) ) { + float value = toFloat(bias[mi][ni].elt(ii * 4 + jj)); + this->elt_[2 * mi + ii][4 * ni + jj] += value; + } + } + } + } + } + } + // Pack the data to a fragment for the next GEMM. template inline __device__ void pack(Fragment_a (&dst)[K][M]) const { diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h index ecb8aef7fa31df..110dda25f086a4 100644 --- a/csrc/flash_attn/src/fmha/utils.h +++ b/csrc/flash_attn/src/fmha/utils.h @@ -410,6 +410,25 @@ inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ uint16_t float_pack(float a); + +template <> +inline __device__ uint16_t float_pack<__half>(float a) { + __half result = __float2half_rn(a); + return reinterpret_cast(result); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint16_t float_pack<__nv_bfloat16>(float a) { + __nv_bfloat16 result = __float2bfloat16_rn(a); + return reinterpret_cast(result); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a,a); } @@ -1211,5 +1230,20 @@ __device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Opera } //////////////////////////////////////////////////////////////////////////////////////////////////// +template __device__ +inline float toFloat(T a) { + return (float)a; +} +template<> __device__ +inline float toFloat(half a) { + return __half2float(a); +} +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> __device__ +inline float toFloat(__nv_bfloat16 a) { + return __bfloat162float(a); +} +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha diff --git a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu index bfafa20ea4eda1..c6c45177e44164 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu @@ -1,6 +1,5 @@ /* Copyright (c) 2022, Tri Dao. */ - #include "fmha.h" #include "fmha_block_dgrad_kernel_1xN_loop.h" @@ -61,4 +60,4 @@ void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_ using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>; run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); } -} \ No newline at end of file +} diff --git a/csrc/flash_attn/src/fmha_bwd_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_hdim32.cu index a09ebac2b1d1a8..e590c1f9fe8108 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim32.cu @@ -14,4 +14,4 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const b run_fmha_bwd_loop(params, stream, configure); } })); -} \ No newline at end of file +} diff --git a/csrc/flash_attn/src/fmha_bwd_launch_template.h b/csrc/flash_attn/src/fmha_bwd_launch_template.h index 324e30411c126e..07d13b9c1fad33 100644 --- a/csrc/flash_attn/src/fmha_bwd_launch_template.h +++ b/csrc/flash_attn/src/fmha_bwd_launch_template.h @@ -44,6 +44,11 @@ __global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params param fmha::compute_dq_dk_dv_seqparallel(params); } +template +__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { + fmha::compute_dq_dk_dv_1xN(params); +} + template void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); @@ -113,3 +118,124 @@ void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const boo FMHA_CHECK_CUDA(cudaPeekAtLastError()); })); } + +template +__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { + fmha::compute_dq_dk_dv_1xN_with_bias_mask(params); +} + +template +void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); + static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + + constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2; + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv); + + bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" + + bool has_attn_mask = !(params.attn_mask_ptr == nullptr); + bool has_attn_bias = !(params.attn_bias_ptr == nullptr); + + if (has_attn_mask) { + if (has_attn_bias) { + BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + }else{ + if (has_attn_bias) { + BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + BOOL_SWITCH_FUNC(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + } +} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim128.cu b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim128.cu new file mode 100644 index 00000000000000..11236eda72d1c2 --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim128.cu @@ -0,0 +1,12 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(params.is_bf16, ([&] { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + })); +} diff --git a/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim16.cu b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim16.cu new file mode 100644 index 00000000000000..c441abf4764b33 --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim16.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_with_mask_bias_hdim16(FMHA_dgrad_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(params.is_bf16, ([&] { + if( params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else if( params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else { + // TD [2022-05-15] 512 gives wrong results rn + // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>; + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim32.cu new file mode 100644 index 00000000000000..7cb676c433a516 --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim32.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(params.is_bf16, ([&] { + if( params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else if( params.seqlen_k >= 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim64.cu b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim64.cu new file mode 100644 index 00000000000000..03d5e8949fe8cb --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_with_mask_bias_hdim64.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream) { + auto dprops = GetDeviceProperties(-1); + FP16_SWITCH(params.is_bf16, ([&] { + if( params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else if( params.seqlen_k >= 256 ) { + if (dprops->major == 8 && dprops->minor == 0) { + // Don't share smem for K & V, and don't keep V in registers + // This speeds things up by 2-3% by avoiding register spills, but it + // uses more shared memory, which is fine on A100 but not other GPUs. + // For other GPUs, we keep V in registers. + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else if (dprops->major == 8 && dprops->minor > 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } else if (dprops->major == 7 && dprops->minor == 5) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_dgrad_fp16_sm80_loop_(params, stream); + } + } + })); +} diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 5e890b259fd246..aaf5900a5b032a 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -836,4 +836,724 @@ inline __device__ void compute_dq_dk_dv_seqparallel(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void compute_dq_dk_dv_1xN_one_iter_with_bias_mask( + const Params ¶ms, Prng &ph, + const int loop_step_idx) { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_dq = typename Kernel_traits::Cta_tile_o; + // The description of the CTA tile for the 3rd batched GEMM. + using Cta_tile_dkv = + fmha::Cta_tile_extd; + + static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128); + static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); + static_assert(Cta_tile_dkv::K == 16); + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_dq = fmha::Hmma_tile; + // The MMA tile for the 3rd GEMM. + using Mma_tile_dkv = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + // The shared memory tile to reload Q transposed. + using Smem_tile_qt = fmha::Smem_tile_b; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + // The shared memory tile to swizzle K^T. Treat K^T as V + using Smem_tile_kt = typename Kernel_traits::Smem_tile_v; + + // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_k; + + // The global memory tile to load dO. + using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; + // The shared memory tile to load dO. + // Treating dO as Q. + using Smem_tile_do = typename Kernel_traits::Smem_tile_q; + // The shared memory tile to reload dO transposed. + using Smem_tile_dot = fmha::Smem_tile_b; + + // The global memory tile to load O.Loading O here is similar to loading dO. + using Gmem_tile_o = Gmem_tile_do; + + // The global memory tile to store dQ. + using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_dq_tmp = fmha::Gmem_tile_o; + // The shared memory tile to swizzle dQ. + using Smem_tile_dq = typename Kernel_traits::Smem_tile_o; + + // The global memory tile to store dV. + using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle dV. + using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; + + // The global memory tile to store dK. + using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle dK. + using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; + static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); + static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Smem_tile_st = typename Kernel_traits::Smem_tile_st; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + // using Gemm1 = Gemm_Q_K; + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. + extern __shared__ char smem_[]; + // Shared memory layout if we keep V in registers: + // dO | Q | K / V | dQ | S | dP | dP_sum + // dV | dK + // Shared memory layout if we keep V shared memory: + // dO | Q | K | V | dQ | S | dP | dP_sum + // dV | dK + + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + // if( binfo.stop_early() ) return; + if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; + + Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true); + // Allocate the global memory tile loader for dQ. + Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx); + Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; + + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false); + // The base pointer of smem_v; + char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! + Smem_tile_v smem_v(smem_v_, tidx); + // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!! + Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for dO. + Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true); + // Allocate the shared memory tile loader for dO. + Smem_tile_do smem_do(&smem_[0], tidx); + Smem_tile_dot smem_dot(&smem_[0], tidx); + // Allocate the shared memory tile loader for Q^T. + // TODO: assert that this points to the same memory as gemm_q_k.smem_q + Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); + + Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx); + Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx); + + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true); + + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx); + + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx); + + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0; + const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin; + + // Wind gmem tiles to the correct position. + gmem_q.move(begin); + gmem_do.move(begin); + gmem_o.move(begin); + gmem_dq.move(begin); + gmem_dq_tmp.move(begin); + // TODO: need to move gmem_s if we want the intermediate result for debugging + gmem_softmax_lse.move(begin); + gmem_softmax_d.move(begin); + + if constexpr (has_attn_mask) { + gmem_mask.move(begin); + } + + if constexpr (has_attn_bias) { + gmem_bias.move(begin); + gmem_ds.move(begin); + } + + if (!Is_first) { + gmem_k.move(loop_step_idx); + gmem_v.move(loop_step_idx); + } + + // Trigger the loads for K. + gmem_k.load(); + // Trigger the loads for Q. + gmem_q.load(); + // Trigger the loads for V. + gmem_v.load(); + // Trigger the loads for dO. + gmem_do.load(); + // Trigger the loads for O. + if (Is_first) { gmem_o.load(); } + + float p_lse[Mma_tile_p::MMAS_M * 2]; + gmem_softmax_lse.load(reinterpret_cast(p_lse)); + gmem_softmax_lse.move(); + + if (!Is_first) { __syncthreads(); } + // Commit the data for Q, dO, and V to shared memory. + gmem_q.commit(gemm_q_k.smem_q); + gmem_do.commit(smem_do); + + if (Is_first) { + dot_do_o( + gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx + ); + } + + // // Instead of scaling dP by rp_dropout, we scale V instead + // if (Is_dropout) { + // const uint32_t scale_dropout = params.scale_dropout; + // #pragma unroll + // for(int it=0; it < Gmem_tile_v::LDGS; it++){ + // gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); + // } + // } + + gmem_v.commit(smem_v); + + // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); + // #pragma unroll + // for(int it=0; it < Gmem_tile_k::LDGS; it++){ + // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + // } + + // Commit the data for K to shared memory. + if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire kernel. + typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N]; + if (Kernel_traits::V_IN_REGS) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { + smem_v.load(frag_v[ki], ki); + } + } + + float dp_sum[Mma_tile_p::MMAS_M * 2]; + gmem_softmax_d.load(reinterpret_cast(dp_sum)); + gmem_softmax_d.move(); + + // Commit the data for V to shared memory if it has not been done already. + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + // Load the fragments for K^T. + // typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; + // smem_kt.load(frag_kt[0], 0); + // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N]; + // #pragma unroll + // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) { + // smem_kt.load(frag_kt[ki], ki); + // } + + // Create the object to do the softmax. + // We won't be using the shared memory for this softmax at all + Softmax softmax(params, smem_, tidx); + + // Declare the accumulators for the 3rd gemm. + fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dv); + fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dk); + + // Load over the entire sequence length. + for( int l = 0; l < steps; l++ ) { + const int loop = (begin + l) * Cta_tile_p::M; + if( loop >= binfo.actual_seqlen_q ) + break; + + // Load the fragments for V. + // typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N]; + if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); } + + // Load the fragments for dO. + typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M]; + smem_do.load(frag_do[0], 0); + + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); + + // Do this part of P^T = (Q * K^T)^T. + gemm_q_k(acc_p); + + // Load the mask for that iteration. + mask.load(begin + l); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + if constexpr (has_attn_mask) { + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_mask.template load(frag_mask); + gmem_mask.move(); + + // Apply the attn mask. + softmax.apply_attn_mask(frag_mask, mask); + } + + if constexpr (has_attn_bias) { + using Frag_Bias = fmha::Fragment_c; + Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_bias.template load(frag_bias); + gmem_bias.move(); + + // Apply the attn mask. + softmax.apply_attn_bias(frag_bias, mask); + } + + // Apply the mask. + softmax.apply_mask(mask); + // Scale by log-sum-exp of the softmax + // softmax.apply_exp(p_lse); + // exp (x - (max+log(sum))) = exp(x - max) / sum + softmax.template scale_apply_exp(p_lse, params.scale_bmm1f); + + if (Is_dropout) { + // softmax.apply_dropout(ph, params.p_dropout_in_uint); + // softmax.template apply_dropout(ph, params.p_dropout_in_uint); + softmax.template apply_dropout_16bits(ph, params.p_dropout_in_uint16_t); + } + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; + static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); + softmax.template pack(frag_p); + + // Store s * dmask to smem for transpose + smem_s.store(frag_p); + + // Trigger the load for the next Q values. + if( l < steps - 1) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(); + } + + // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { + // // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction + // __syncthreads(); + // } + + fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + #pragma unroll + for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) { + #pragma unroll + for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) { + #pragma unroll + for (int ii = 0; ii < 8; ++ii) { + acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)]; + } + } + } + + // Do this part of dP^T = (dO * V^T)^T. + #pragma unroll + for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of dO values. + smem_do.load(frag_do[ki & 1], ki); + if (!Kernel_traits::V_IN_REGS) { + smem_v.load(frag_v[ki & 1], ki); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + } else { + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); + } + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { + // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); + // printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y); + // tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1])); + // printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y); + // } + } + + // Do the final stage of math. + { + int ki = Mma_tile_p::MMAS_K; + if (!Kernel_traits::V_IN_REGS) { + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + } else { + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); + } + } + + auto pointwise_mult = [](float p, float dp, float d) { + return p * ((!Is_dropout) || p >= 0.f ? dp : d); + }; + #pragma unroll + for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) { + #pragma unroll + for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) { + softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]); + softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]); + softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]); + } + } + + // Load the fragments for K^T. + typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N]; + smem_kt.load(frag_kt[0], 0); + + // Trigger the load for the next dO values. + if( l < steps - 1) { + smem_do.move_to_next_write_buffer(); + gmem_do.move(); + gmem_do.load(); + if (Is_first) { + gmem_o.move(); + gmem_o.load(); + } + } + + softmax.template pack(frag_p); + + if constexpr (has_attn_bias) { + gmem_ds.template store(softmax.elt_); + gmem_ds.move(); + } + + // Store dp to smem for transpose + smem_dp.store(frag_p); + + // gmem_s.store(frag_p, mask); + // gmem_s.move(); + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N]; + fmha::Clear_accumulator::apply(acc_dq); + + // Do this part of O = P^T * V^T. + #pragma unroll + for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + smem_kt.load(frag_kt[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + } + // Do the final stage of math. + { + int ki = Mma_tile_dq::MMAS_K; + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + } + + static_assert(Gmem_tile_dq::LOOPS == 1); + + // Swizzle the elements and do the final reduction. + // Need to syncthreads here, otherwise the smem_dq reads from the previous iteration + // might happen after the smem_dq writes in this iteration. + __syncthreads(); + smem_dq.store(acc_dq, 0); + + typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N]; + static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4); + static_assert(Mma_tile_dkv::MMAS_K == 1); + smem_dot.load(frag_dot[0], 0); + + // Threads in a warp is communicating via shared memory (smem_s and smem_dp) + __syncwarp(); + typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; + smem_s.load(frag_s); + + if (Is_dropout) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { + #pragma unroll + for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { + frag_s[ki][mi].template hrelu_(); + } + } + } + + #pragma unroll + for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + smem_dot.load(frag_dot[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_dkv::MMAS_K; + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0])); + // printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y); + // float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1])); + // printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y); + // } + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3)); + // printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3)); + // } + // __syncthreads(); + // Commit the values for Q and dO into shared memory. + if(l < steps - 1) { + gmem_q.commit(gemm_q_k.smem_q); + } + + uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP]; + if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); } + + // __syncthreads(); + // Commit the values for Q and dO into shared memory. + if(l < steps - 1) { + gmem_do.commit(smem_do); + if (Is_first) { + dot_do_o( + gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx + ); + } + gmem_softmax_lse.load(reinterpret_cast(p_lse)); + gmem_softmax_lse.move(); + } + + typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M]; + smem_dp.load(frag_dpt); + + gemm_q_k.reload_k(); + + typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N]; + static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); + static_assert(Mma_tile_dkv::MMAS_K == 1); + smem_qt.load(frag_qt[0], 0); + + #pragma unroll + for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + smem_qt.load(frag_qt[ki & 1], ki); + // Do the math for the values already in registers. + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Do the final stage of math. + { + int ki = Mma_tile_dkv::MMAS_K; + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + } + + // Make sure dQ is in shared memory. + __syncthreads(); + + if (l < steps - 1) { + gmem_softmax_d.load(reinterpret_cast(dp_sum)); + gmem_softmax_d.move(); + } + + // Load from shared memory. + smem_dq.template load(dq_out); + + const bool is_final_write = + Is_last + || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) + || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); + if (is_final_write) { + // if (Is_dropout) { + // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); + // } + for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { + // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); + dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout); + } + // Output the values. + gmem_dq.template store(dq_out, 0); + // Move to the next part of the output. + gmem_dq.move(); + } else { + // Output the values. + gmem_dq_tmp.store(dq_out, 0); + } + + // Move to the next part of the output. + if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); } + + // // Make sure the data is in shared memory. + // __syncthreads(); + + // Commit the values for Q and dO into shared memory. + if(l < steps - 1) { + gemm_q_k.smem_q.move_to_next_read_buffer(); + gemm_q_k.reload_q(); + smem_qt.move_to_next_read_buffer(); + // smem_qt.load(frag_qt[0], 0); + smem_do.move_to_next_read_buffer(); + smem_dot.move_to_next_read_buffer(); + // smem_dot.load(frag_dot[0], 0); + } + + } // Outer loop over the sequence length. + + if (Is_dropout) { + for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { + for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { + acc_dv[mi][ni].mul_(params.rp_dropout); + } + } + } + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3)); + // printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3)); + // } + for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { + for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) { + // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f); + // acc_dk[mi][ni].mul_(params.scale_bmm1f); + acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout); + } + } + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1)); + // } + + __syncthreads(); + // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than + // the total amount of shared mem? + // Epilogue swizzle for dV + // data flow: fragment -> smem_dv -> global + Smem_tile_dv smem_dv(&smem_[0], tidx); + smem_dv.template store(acc_dv); + + // Epilogue swizzle for dK + Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); + smem_dk.template store(acc_dk); + + __syncthreads(); + uint4 dv_out[Smem_tile_dv::NUM_LDS]; + smem_dv.load(dv_out); + Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false); + if (!Is_first) { + gmem_dv.move(loop_step_idx); + } + gmem_dv.store(dv_out); + + uint4 dk_out[Smem_tile_dk::NUM_LDS]; + smem_dk.load(dk_out); + Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false); + if (!Is_first) { + gmem_dk.move(loop_step_idx); + } + gmem_dk.store(dk_out); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_1xN_with_bias_mask(const Params ¶ms) { + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. + const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx; + auto seeds = philox::unpack(params.philox_args); + Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + + if (loop_steps == 1) { + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, 0); + } else if (loop_steps == 2) { + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, 0); + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, 1); + } else { + if (params.seqlen_k == blocksize_c) { + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, 0); + } else { + const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, loop_step_idx); + } + compute_dq_dk_dv_1xN_one_iter_with_bias_mask(params, ph, max_loop_steps - 1); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace fmha diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index ceb115c950549b..842c674d74db74 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -658,6 +658,495 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } // Outer loop over the sequence length. } +template +inline __device__ void device_1xN_with_mask_bias(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + + // The description of the CTA tile for the 1st batched GEMM. + using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + // The global memory tile to load Q. + using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; + + // The global memory tile to load K. + using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; + + // The global memory tile to load V. + using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; + // The shared memory tile to swizzle V. + using Smem_tile_v = typename Kernel_traits::Smem_tile_v; + + // The global memory tile to store O. + using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; + using Gmem_tile_o_tmp = fmha::Gmem_tile_o; + // The shared memory tile to swizzle O. + using Smem_tile_o = typename Kernel_traits::Smem_tile_o; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + // if( binfo.stop_early() ) return; + if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; + + Gemm1 gemm_q_k(smem_, tidx); + + // Allocate the global memory tile loader for Q. + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true); + // Allocate the global memory tile loader for O. + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + + // Wind gmem tiles to the correct position. + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + const int begin_og = begin; + begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin; + const int steps_og = steps; + steps -= begin - begin_og; + gmem_q.move(begin); + gmem_o.move(begin); + gmem_o_tmp.move(begin); + if (Return_softmax) { gmem_s.move(begin); } + gmem_softmax_lse.move(begin); + + if constexpr (has_attn_mask) { + gmem_mask.move(begin); + } + + if constexpr (has_attn_bias) { + gmem_bias.move(begin); + } + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for K. + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false); + // Allocate the global memory tile loader for V. + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false); + // The base pointer of smem_v; + char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + // smem_ is continous memory, each part is v, o + + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! + Smem_tile_v smem_v(smem_v_, tidx); + + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + if (!Is_first) { + gmem_k.move(loop_step_idx); + gmem_v.move(loop_step_idx); + if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } + } + + // Trigger the loads for K. + gmem_k.load(); + // Trigger the loads for Q. + gmem_q.load(); + // Trigger the loads for V. + gmem_v.load(); + + if (!Is_first) { __syncthreads(); } + + float p_prev_lse[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + gmem_softmax_lse.load(reinterpret_cast(p_prev_lse)); + } + + // Commit the data for Q and V to shared memory. + gmem_q.commit(gemm_q_k.smem_q); + gmem_v.commit(smem_v); + + // const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); + // #pragma unroll + // for(int it=0;it < Gmem_tile_k::LDGS;it++){ + // gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); + // } + + // Commit the data for K to shared memory. + if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + gmem_k.commit(gemm_q_k.smem_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire kernel. + typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + smem_v.load(frag_v[ki], ki); + } + + // Commit the data for V to shared memory if it has not been done already. + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for V. + gmem_k.commit(gemm_q_k.smem_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); + + Smem_softmax_sum smem_softmax_lse(reinterpret_cast(&smem_[Gemm1::SMEM_BYTES]), tidx); + + // Load over the entire sequence length. + for( int l = 0; l < steps; l++ ) { + if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break; + + // Declare the accumulators for the 1st gemm. + fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::Clear_accumulator::apply(acc_p); + + // Do this part of P = Q * K^T. + gemm_q_k(acc_p); + + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); + // } + + + uint4 out[Gmem_tile_o::STGS_PER_LOOP]; + if (!Is_first) { gmem_o_tmp.load(out, 0); } + + // Trigger the load for the next Q values. + if( l < steps - 1) { + gemm_q_k.smem_q.move_to_next_write_buffer(); + gmem_q.move(); + gmem_q.load(); + } + + // Load the mask for that iteration. + mask.load(begin + l); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + + if constexpr (has_attn_mask) { + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::clear(frag_mask); + gmem_mask.template load(frag_mask); + gmem_mask.move(); + + // Apply the attn mask. + softmax.apply_attn_mask(frag_mask, mask); + } + + if constexpr (has_attn_bias) { + using Frag_Bias = fmha::Fragment_c; + Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::clear(frag_bias); + gmem_bias.template load(frag_bias); + gmem_bias.move(); + + // Apply the attn mask. + softmax.apply_attn_bias(frag_bias, mask); + } + + // Apply the mask. + // this impl is more like padding + softmax.apply_mask(mask); + + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { + // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction + __syncthreads(); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]); + // } + // } + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + smem_softmax_lse.store_pair(p_prev_lse, l % 2); + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; } + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; + } + } + + // Trigger the load for the next LSE values. + if( l < steps - 1) { + if (!Is_first) { + gmem_softmax_lse.load_next(reinterpret_cast(p_prev_lse)); + } + } + + softmax.template reduce_max(p_max); + + // if ((threadIdx.x == 0) && (l == 38)) { + // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); + // } + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); + // } + // } + + // Compute the exponential value. + // softmax.apply_exp(p_max); + softmax.scale_apply_exp(p_max, params.scale_bmm1f); + + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]); + // } + // } + + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + // if (!Is_first) { + // int warp = tidx / Cta_tile_p::THREADS_PER_WARP; + // int lane = tidx % Cta_tile_p::THREADS_PER_WARP; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + // p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0; + // } + // } + // softmax.reduce_sum(p_sum); + softmax.reduce_sum_before_sync_(p_sum); + // softmax.template reduce_sum_before_sync_(p_sum); + + // float p_sum_log[Mma_tile_p::MMAS_M * 2]; + // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { + // float sum = p_sum[mi]; + // // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum); + // constexpr float kLog2e = M_LOG2E; + // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum); + // } + // // gmem_softmax_lse.store(reinterpret_cast(p_sum)); + // gmem_softmax_lse.store(reinterpret_cast(p_sum_log)); + // gmem_softmax_lse.move(); + + // // Finalize softmax on the accumulators of P^T. + // softmax.scale(p_sum); + + + constexpr bool encode_dropout_in_sign_bit = Return_softmax; + if (Is_dropout) { + // softmax.template apply_dropout(ph0, params.p_dropout_in_uint); + // softmax.template apply_dropout(ph0, ph1, params.p_dropout_in_uint); + softmax.template apply_dropout_16bits(ph0, ph1, params.p_dropout_in_uint16_t); + } + + using Frag_p = fmha::Fragment_a; + Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + // frag_p = exp(s{i} - s{max}) + softmax.template pack(frag_p); + + if (Return_softmax) { + gmem_s.store(frag_p, mask); + gmem_s.move(); + } + + // Commit the values for Q into shared memory. + if(l < steps - 1) { + gmem_q.commit(gemm_q_k.smem_q); + } + + if (Is_dropout && encode_dropout_in_sign_bit) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { + #pragma unroll + for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { + frag_p[ki][mi].template hrelu_(); + } + } + } + + // Declare the accumulators for the 2nd gemm. + fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; + fmha::Clear_accumulator::apply(acc_o); + + // Do this part of O = P^T * V^T. + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); + // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki])); + // float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki])); + // printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0)); + // } + } + + // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0)); + // } + + // The mapping from tidx to rows changes between the softmax and the + // O-reduction. So we recalculate the max. + float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + int rows[Gmem_tile_o::STGS_PER_LOOP]; + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; + } + // When d = 16, O only has 16 x 16 = 256 elements, and each of the 128 threads wants + // to write 4 elements, so only half of the thread should deal with O. + bool o_rows_are_valid = + (Kernel_traits::THREADS <= Gmem_tile_o::THREADS_PER_ROW * Gmem_tile_o::ROWS) + || (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS); + if (o_rows_are_valid) { + softmax.reduce_max_after_sync_(p_max_o, rows); + } + static_assert(Mma_tile_o::MMAS_M == 1); + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_max_o[jj][0] *= params.scale_bmm1f; + } + float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP]; + if ((!Is_first) && o_rows_are_valid) { + smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); + } + // if (!Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]); + // } + // } + + static_assert(Gmem_tile_o::LOOPS == 1); + + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o, 0); + + // Make sure the data is in shared memory. + __syncthreads(); + + static_assert(Mma_tile_o::MMAS_M == 1); + float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + if (o_rows_are_valid) { + softmax.reduce_sum_after_sync_(p_sum_o, rows); + } + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); + p_sum_o[jj][0] += p_prev_scale_o[jj]; + } + } + + float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; + #pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); + // if (sum == 0.f || sum != sum) { + // printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]); + // } + // if (Is_first) { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("p_sum_log=%.6f\n", p_sum_log[jj][0]); + // } + // } + if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && o_rows_are_valid) { + gmem_softmax_lse.store_row( + reinterpret_cast(p_sum_log[jj]), rows[jj]); + } + } + gmem_softmax_lse.move(); + + // Load from shared memory. + if (!Is_first) { + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]); + } + } + smem_o.template load(out); + + const bool is_final_write = + Is_last + || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) + || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); + #pragma unroll + for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + float sum = p_sum_o[jj][0]; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + if (Is_dropout && is_final_write) { + inv_sum *= params.rp_dropout; + } + out[jj] = fmha::fmul4(out[jj], inv_sum); + } + + // if (Is_dropout && Is_last) { + // for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { + // out[jj] = fmha::fmul4(out[jj], params.rp_dropout); + // } + // } + + // Output the values. + if (is_final_write) { + gmem_o.template store(out, 0); + gmem_o.move(); + } else { + gmem_o_tmp.store(out, 0); + } + + // Move to the next part of the output. + if (!(Is_first && Is_last)) { gmem_o_tmp.move(); } + gemm_q_k.reload_k(); + + // Make sure we are reading from the correct buffer. + gemm_q_k.smem_q.move_to_next_read_buffer(); + // Trigger the load from shared memory for the next series of Q values. + if(l < steps - 1) { + gemm_q_k.reload_q(); + } + + } // Outer loop over the sequence length. +} + //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -695,6 +1184,41 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { } } + +template +inline __device__ void device_1xN_loop_with_mask_bias(const Params ¶ms) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. + const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; + + auto seeds = philox::unpack(params.philox_args); + Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); + constexpr int M = Kernel_traits::Cta_tile_p::M; + + const int STEPS = (params.seqlen_q + M - 1) / M; + // iterative over q, stride with M, block size + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + + if (params.seqlen_k == blocksize_c) { + fmha::device_1xN_with_mask_bias(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + } else { + const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; + // iterative with k + fmha::device_1xN_with_mask_bias(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { + fmha::device_1xN_with_mask_bias(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); + } + fmha::device_1xN_with_mask_bias(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1); + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace fmha diff --git a/csrc/flash_attn/src/fmha_fwd_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_hdim32.cu index f569ca5f6a9526..ed15e1318022fd 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim32.cu @@ -14,4 +14,4 @@ void run_fmha_fwd_hdim32(Launch_params &launch_params) { run_fmha_fwd_loop(launch_params); } })); -} \ No newline at end of file +} diff --git a/csrc/flash_attn/src/fmha_fwd_launch_template.h b/csrc/flash_attn/src/fmha_fwd_launch_template.h index ce4a8cb86d1d14..abccec4f14bc3c 100644 --- a/csrc/flash_attn/src/fmha_fwd_launch_template.h +++ b/csrc/flash_attn/src/fmha_fwd_launch_template.h @@ -90,3 +90,136 @@ void run_fmha_fwd_loop(Launch_params &launch_params) { FMHA_CHECK_CUDA(cudaPeekAtLastError()); })); } + + +template +__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) { + fmha::device_1xN_loop_with_mask_bias(params); +} + +template +void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, + const bool configure) { + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr int M = Kernel_traits::Cta_tile_p::M; + size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; + launch_params.elts_per_thread = elts_per_head; + return; + } + + constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const int smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? smem_size_softmax_lse : 0); + + bool has_attn_mask = !(launch_params.params.attn_mask_ptr == nullptr); + bool has_attn_bias = !(launch_params.params.attn_bias_ptr == nullptr); + + if (has_attn_mask) { + if (has_attn_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH_FUNC. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH_FUNC(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH_FUNC. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH_FUNC(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + }else{ + if (has_attn_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH_FUNC. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH_FUNC(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH_FUNC. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH_FUNC(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + } +} diff --git a/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim128.cu b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim128.cu new file mode 100644 index 00000000000000..38ffdbe57f7a0a --- /dev/null +++ b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim128.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_with_mask_bias_hdim128(Launch_params &launch_params, + const bool configure) { + auto dprops = GetDeviceProperties(-1); + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) { + // TD [2022-06-05] Keep K in registers to reduce register spilling + // Gives about 6% speedup compared to using block size 128. + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + })); +} diff --git a/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim16.cu b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim16.cu new file mode 100644 index 00000000000000..84d4e8a3ad97b9 --- /dev/null +++ b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim16.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_with_mask_bias_hdim16(Launch_params &launch_params, + const bool configure) { + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + // TD [2022-05-15] 512 gives wrong results rn + // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>; + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim32.cu new file mode 100644 index 00000000000000..e7b5e38e085eee --- /dev/null +++ b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim32.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_with_mask_bias_hdim32(Launch_params &launch_params, + const bool configure) { + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim64.cu b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim64.cu new file mode 100644 index 00000000000000..b2f0d30a012144 --- /dev/null +++ b/csrc/flash_attn/src/fmha_fwd_with_mask_bias_hdim64.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "fmha_fwd_launch_template.h" + +void run_fmha_fwd_with_mask_bias_hdim64(Launch_params &launch_params, + const bool configure) { + auto dprops = GetDeviceProperties(-1); + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k >= 256 ) { + if (dprops->major == 8 && dprops->minor >= 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (dprops->major == 7 && dprops->minor == 5) { + if (launch_params.is_dropout) { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + } + })); +} diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 53bcf35d6936e6..ccd4785060a113 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -38,3 +38,25 @@ F(); \ } \ } + +#define BOOL_SWITCH_FUNC(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH_FUNC(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = __half; \ + return __VA_ARGS__(); \ + } \ + }() \ No newline at end of file