Disable fused causal attention #14732

Merged (3 commits) on Feb 21, 2023
30 changes: 17 additions & 13 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -25,7 +25,7 @@ enum AttentionQkvFormat {
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
};

enum AttentionKernelType{
enum AttentionKernelType {
AttentionKernel_Unfused,
AttentionKernel_TrtFusedAttention,
AttentionKernel_TrtFlashAttention,
@@ -38,15 +38,15 @@ enum AttentionKernelType{
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;
@@ -56,13 +56,17 @@ struct AttentionParameters {
};

namespace attention {
// Environment variable to enable or disable fused self/causal attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";

// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled).
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
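The constants above are read from environment variables when the attention kernels are constructed. As a minimal illustrative sketch (not part of this PR; the model path and provider list are placeholders), the new opt-in switch can be exercised from Python like this:

import os

import onnxruntime as ort

# The kernels parse these variables in their constructors, so set them before
# the InferenceSession (and therefore the kernels) is created.
os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"  # opt in to the fp16-accumulated causal kernels
os.environ["ORT_DISABLE_FUSED_ATTENTION"] = "0"        # keep TRT fused self attention enabled

# "gpt2_fp16.onnx" is a placeholder model path.
session = ort.InferenceSession("gpt2_fp16.onnx", providers=["CUDAExecutionProvider"])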
14 changes: 8 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -39,12 +39,15 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);

enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);

enable_fused_causal_attention_ = sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);

#if USE_FLASH_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
@@ -97,14 +100,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
int sm = device_prop.major * 10 + device_prop.minor;
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;

if (is_unidirectional_) { // GPT
if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
// GPT fused kernels requires left side padding. mask can be:
// none (no padding), 1D sequence lengths or 2d mask.
// Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
// where past state is empty.
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
bool use_causal_fused_runner = !disable_fused_runner_ &&
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
nullptr == relative_position_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
@@ -121,7 +123,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
fused_runner = fused_fp16_runner_.get();
}
} else { // BERT
bool use_fused_runner = !disable_fused_runner_ &&
bool use_fused_runner = !disable_fused_self_attention_ &&
(nullptr == mask_index || is_mask_1d_seq_len) &&
nullptr == past &&
nullptr == present &&
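To summarize the gating change above: the causal fused runner is now reached only when the operator is unidirectional and ORT_ENABLE_FUSED_CAUSAL_ATTENTION is set. A simplified sketch of the visible conditions, in Python for readability (not part of the PR; the real check applies further constraints, such as GPU architecture, that are elided in this hunk):

def may_use_causal_fused_runner(is_unidirectional, enable_fused_causal_attention,
                                mask_type, has_relative_position_bias,
                                past_sequence_length, hidden_size, v_hidden_size):
    # Mirrors the conditions visible in the diff; mask_type of None means no mask.
    return (
        is_unidirectional
        and enable_fused_causal_attention
        and mask_type in (None, "1d_key_seq_len", "2d_key_padding")
        and not has_relative_position_bias
        and past_sequence_length == 0  # only the first token, where the past state is empty
        and hidden_size == v_hidden_size
    )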
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -21,8 +21,9 @@ class Attention final : public CudaKernel, public AttentionBase {
Status ComputeInternal(OpKernelContext* context) const override;

protected:
bool disable_fused_runner_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
};
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -620,12 +620,14 @@ Status QkvToContext(
if (use_fused_kernel || use_fused_causal) {
int* sequence_offset = reinterpret_cast<int*>(scratch1);
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
DUMP_TENSOR_D("mask", reinterpret_cast<const int*>(data.mask_index), batch_size, sequence_length);
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
} else {
sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
data.mask_index, batch_size, sequence_length, stream,
sequence_offset);
}
DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
CUDA_RETURN_IF_ERROR(cudaGetLastError());

FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -42,8 +42,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);

disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);

enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
@@ -124,7 +124,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}

bool use_fused_runner = !disable_fused_runner_ &&
bool use_fused_runner = !disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == relative_position_bias &&
(value != nullptr || key == nullptr) &&
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -24,7 +24,7 @@ class MultiHeadAttention final : public CudaKernel {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
bool disable_fused_runner_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
bool disable_memory_efficient_attention_;
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -543,6 +543,7 @@ def get_ort_environment_variables():
# Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
env_names = [
"ORT_DISABLE_FUSED_ATTENTION",
"ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
"ORT_DISABLE_TRT_FLASH_ATTENTION",
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
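For reference, a minimal sketch of how a benchmark run could report which of the listed switches are set; get_ort_environment_variables serves this purpose, though its exact return format is not shown in the hunk:

import os

env_names = [
    "ORT_DISABLE_FUSED_ATTENTION",
    "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
    "ORT_DISABLE_FUSED_CROSS_ATTENTION",
    "ORT_DISABLE_TRT_FLASH_ATTENTION",
    "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
]
print(",".join(f"{name}={os.environ[name]}" for name in env_names if name in os.environ))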
18 changes: 10 additions & 8 deletions onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
@@ -24,7 +24,6 @@

from benchmark_helper import Precision
from float16 import float_to_float16_max_diff
from fusion_options import AttentionMaskFormat
from io_binding_helper import IOBindingHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
@@ -188,6 +187,7 @@ def get_dummy_inputs(
input_ids_dtype: torch.dtype = torch.int32,
position_ids_dtype: torch.dtype = torch.int32,
attention_mask_dtype: torch.dtype = torch.int32,
left_side_padding: bool = True,
) -> Gpt2Inputs:
"""Create random inputs for GPT2 model.
Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
@@ -218,9 +218,14 @@ def get_dummy_inputs(
dtype=attention_mask_dtype,
device=device,
)

if total_sequence_length >= 2:
padding_position = random.randint(0, total_sequence_length - 1) # test input with padding.
attention_mask[:, padding_position] = 0
for i in range(batch_size):
padding_length = random.randint(0, total_sequence_length - 1)
if left_side_padding:
attention_mask[i, :padding_length] = 0
else: # right side padding
attention_mask[i, total_sequence_length - padding_length :] = 0

# Deduce position_ids from attention mask
position_ids = None
Expand Down Expand Up @@ -517,11 +522,6 @@ def optimize_onnx(

optimization_options = FusionOptions("gpt2")

if is_float16 and stage == 1:
# For init_decoder, enable mask index to use fused causal cuda kernel.
# Potentially, we can add other optimization like unpad for effective transformer
optimization_options.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

# TODO(hasesh): Investigate parity issue for GPT-2 fp16 when SkipLayerNormalization
# is enabled
if is_float16:
@@ -841,6 +841,7 @@ def test_parity(
input_ids_dtype=input_ids_dtype,
position_ids_dtype=position_ids_dtype,
attention_mask_dtype=attention_mask_dtype,
left_side_padding=True,
)
outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs)
if use_io_binding:
@@ -868,6 +869,7 @@
max_abs_diff_list.append(max_abs_diff)
if is_all_close:
passed_test_cases += 1

if is_top1_matched:
top1_matched_cases += 1
top1_matched_cases_per_run[run_id] += 1
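The fused causal kernels require left-side padding, which is why get_dummy_inputs now pads each sequence on the left by default. A standalone sketch of that masking logic, with illustrative shapes:

import random

import torch

batch_size, total_sequence_length = 2, 8
attention_mask = torch.ones(batch_size, total_sequence_length, dtype=torch.int32)

if total_sequence_length >= 2:
    for i in range(batch_size):
        # Zero out a random number of leading positions per sequence (left-side padding).
        padding_length = random.randint(0, total_sequence_length - 1)
        attention_mask[i, :padding_length] = 0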
9 changes: 6 additions & 3 deletions onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -930,7 +930,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);
@@ -941,7 +942,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);
@@ -952,7 +954,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);
10 changes: 5 additions & 5 deletions onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -181,7 +181,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
@@ -195,7 +195,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
@@ -209,7 +209,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
@@ -224,7 +224,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
@@ -239,7 +239,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(