[CUDA] Lean Attention #22352

Merged 10 commits on Oct 14, 2024

Changes from 2 commits
12 changes: 12 additions & 0 deletions cmake/CMakeLists.txt
@@ -107,6 +107,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)

cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -733,21 +734,25 @@ if (onnxruntime_USE_CUDA)

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

@@ -761,6 +766,13 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()

if (onnxruntime_USE_LEAN_ATTENTION)
message( STATUS "Enable flash attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1)
endif()

if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
message( STATUS "Enable memory efficient attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
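For orientation, a minimal sketch (illustrative, not part of this diff) of what the `-DUSE_LEAN_ATTENTION=1` provider flag added above becomes on the C++ side; the same guard appears in the kernel sources later in this PR:

```cpp
// Compile-time guard produced by the CMake option: when onnxruntime_USE_LEAN_ATTENTION
// is ON, -DUSE_LEAN_ATTENTION=1 is passed to the CUDA EP sources.
#if USE_LEAN_ATTENTION
// lean attention kernels and their scheduling fields are compiled in
#else
// lean attention is stubbed out (non-CUDA builds, CUDA < 11.6, or Windows with CUDA < 12)
#endif
```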
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -48,6 +48,7 @@ enum AttentionKernelType {
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_LeanAttention,
AttentionKernel_Default
};

@@ -169,10 +170,13 @@ enum class AttentionBackend : int {
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
MATH = 16, // unfused kernel cannot be disabled right now.

// The following kernels might be deprecated in the future.
// The following TRT kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,

// Experimental kernels
LEAN_ATTENTION = 256,
};
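A short illustration (assumed usage; only the enum above is taken from this diff) of how these bit-flag values combine into a single selection integer:

```cpp
// Assuming AttentionBackend from attention_common.h is in scope.
// LEAN_ATTENTION (256) | MATH (16) == 272; any value > 0 makes
// AttentionKernelOptions::Initialize take the explicit bit-flag path
// instead of reading the environment variables below.
const int sdpa_kernel_value = static_cast<int>(AttentionBackend::LEAN_ATTENTION) |
                              static_cast<int>(AttentionBackend::MATH);
```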

// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
@@ -200,6 +204,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Environment variable to enable or disable lean attention. Default is 0 (disabled).
constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION";

// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -118,9 +118,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention = false;
}
// Allocate buffers
size_t softmax_lse_bytes = 0;
size_t softmax_lse_accum_bytes = 0;
size_t out_accum_bytes = 0;
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads);

using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
@@ -129,10 +132,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
}
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif
@@ -247,6 +252,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
constexpr bool use_lean_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -257,6 +263,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.total_sequence_length,
fused_runner,
use_flash_attention,
use_lean_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
@@ -289,6 +296,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}

if (softmax_lse_accum_buffer != nullptr) {
data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
}
94 changes: 92 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -39,6 +39,7 @@
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/lean_attention/lean_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

using namespace onnxruntime::cuda;
@@ -108,6 +109,7 @@
size_t total_sequence_length,
void* fused_runner,
bool use_flash_attention,
bool use_lean_attention,

[cpplint] line 112: Do not indent within a namespace. [whitespace/indent_namespace]
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
@@ -119,12 +121,20 @@

#if USE_FLASH_ATTENTION
if (use_flash_attention) {
return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads);
return qkv_bytes;
}
#else
ORT_UNUSED_PARAMETER(use_flash_attention);
#endif

#if USE_LEAN_ATTENTION
if (use_lean_attention) {
return qkv_bytes;
}
#else
ORT_UNUSED_PARAMETER(use_lean_attention);
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
if (use_memory_efficient_attention) {
size_t fmha_buffer_bytes = 0;
@@ -301,7 +311,7 @@

constexpr bool is_bf16 = false;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.softmax_lse),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16,
false, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
@@ -326,6 +336,81 @@
}
#endif

#if USE_LEAN_ATTENTION
template <typename T>
Status LeanAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
[cpplint] lines 342-346: Do not indent within a namespace. [whitespace/indent_namespace]
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
assert(nullptr == data.mask_index);
assert(nullptr == data.attention_bias);
assert(parameters.head_size == parameters.v_head_size);

constexpr bool is_bf16 = false;

ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache(
device_prop, stream,
data.q,
data.k, // k_cache
data.v, // v_cache
nullptr, // new_k (we have appended new_k to k_cache)
nullptr, // new_v (we have appended new_v to v_cache)
data.output,
reinterpret_cast<void*>(data.softmax_lse),
nullptr,
nullptr, // cos_cache
nullptr, // sin_cache
nullptr, // block_table
[cpplint] lines 358, 359, 365, 366: At least two spaces is best between code and comments [whitespace/comments]
parameters.batch_size,
parameters.num_heads,
parameters.num_heads, // num_heads_k
parameters.head_size,
parameters.sequence_length, // seqlen_q
parameters.total_sequence_length, // seqlen_k
0, // seqlen_k_new
0, // rotary_dim
scale, // softmax_scale
parameters.is_unidirectional,
is_bf16,
false, // past_bsnh
parameters.num_splits,
data.grid_dim_z,
data.max_tiles_per_tb,
data.high_load_tbs,
data.tiles_per_head,
reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum),
data.lean_sync_flag,
-1, // local_window_size
false, // is_rotary_interleaved
false // is_packed_qkv
));

return Status::OK();
}

template <>
Status LeanAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(device_prop);
ORT_UNUSED_PARAMETER(stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor");
}
#endif



template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
@@ -641,6 +726,11 @@
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
#if USE_LEAN_ATTENTION
if (data.use_lean_attention) {
return LeanAttention(device_prop, stream, parameters, data, scale);
}
#endif

#if USE_FLASH_ATTENTION
if (data.use_flash_attention) {
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize(
size_t total_sequence_length,
void* fused_runner,
bool use_flash_attention,
bool use_lean_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
@@ -102,6 +103,16 @@ struct AttentionData {
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;

// Lean Attention
bool use_lean_attention = false;
#if USE_LEAN_ATTENTION
int grid_dim_z = 0;
int max_tiles_per_tb = 0;
int high_load_tbs = 0;
int tiles_per_head = 0;
int* lean_sync_flag = nullptr;
#endif

// For Debugging
size_t workspace_bytes = 0;
bool allow_debug_info = false;
@@ -115,6 +126,7 @@ struct AttentionData {

void PrintDebugInfo() const {
std::cout << "flash=" << use_flash_attention
<< ", lean=" << use_lean_attention
<< ", efficient=" << use_memory_efficient_attention
<< ", fused_runner=" << (fused_runner != nullptr)
<< ", fused_cross=" << (fused_cross_attention_kernel != nullptr)
9 changes: 9 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -17,6 +17,7 @@ namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_lean_attention_ = (value & static_cast<int>(AttentionBackend::LEAN_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
@@ -26,6 +27,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_lean_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableLeanAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
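A hedged usage sketch for the environment-variable path above (the variable name comes from kEnableLeanAttention; the helper function and its call site are assumptions): lean attention is opt-in, unlike flash attention, which is opt-out.

```cpp
#include <cstdlib>

// ORT_ENABLE_LEAN_ATTENTION defaults to 0 (disabled); set it before the
// attention kernel options are initialized for the session.
void EnableLeanAttentionViaEnv() {
#ifdef _WIN32
  _putenv_s("ORT_ENABLE_LEAN_ATTENTION", "1");
#else
  setenv("ORT_ENABLE_LEAN_ATTENTION", "1", /*overwrite=*/1);
#endif
}
```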
@@ -61,6 +63,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
use_flash_attention_ = false;
#endif

#ifndef USE_LEAN_ATTENTION
use_lean_attention_ = false;
#endif

#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
@@ -81,6 +87,7 @@ void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
sstream << " LEAN_ATTENTION=" << int(use_lean_attention_);
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
@@ -131,6 +138,8 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
sstream << " SdpaKernel=";
if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << "FLASH_ATTENTION";
} else if (use_lean_attention.has_value() && use_lean_attention.value()) {
sstream << "LEAN_ATTENTION";
} else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << "EFFICIENT_ATTENTION";
} else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -9,6 +9,7 @@
namespace onnxruntime {
struct AttentionKernelDebugInfo {
std::optional<bool> use_flash_attention = std::nullopt;
std::optional<bool> use_lean_attention = std::nullopt;
std::optional<bool> use_efficient_attention = std::nullopt;
std::optional<bool> use_trt_fused_attention = std::nullopt;
std::optional<bool> use_cudnn_flash_attention = std::nullopt;
@@ -24,6 +25,7 @@ class AttentionKernelOptions {
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseLeanAttention() const { return use_lean_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
bool UseTrtFusedAttention() const { return use_trt_fused_attention_; }
bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; }
@@ -44,6 +46,7 @@ class AttentionKernelOptions {

private:
bool use_flash_attention_{true};
bool use_lean_attention_{false};
bool use_efficient_attention_{true};
bool use_trt_fused_attention_{true};
bool use_cudnn_flash_attention_{false};
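A minimal consumer-side sketch (only InitializeOnce and the accessors come from the header above; the dispatch branches and surrounding wiring are illustrative):

```cpp
// sdpa_kernel == 0 means no explicit backend bits, so initialization falls back
// to the environment variables, where lean attention defaults to off.
AttentionKernelOptions options;
options.InitializeOnce(/*sdpa_kernel=*/0, /*use_build_flag=*/true);

if (options.UseLeanAttention()) {
  // dispatch to the lean attention kernel
} else if (options.UseFlashAttention()) {
  // existing flash attention path
}
```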
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -384,6 +384,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,

if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.use_lean_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);