Skip to content

Commit

Permalink
Stable Diffusion CUDA Optimizations Part 4 (#14680)
Browse files Browse the repository at this point in the history
(1) Support packed QKV format in MultiHeadAttention. This format could
avoid add bias transpose when TRT fused kernel is used.
(2) Add cache for cumulated sequence length computation. For SD, it only
need computed once since sequence length is fixed.
(3) Do not allocate qkv workspace to save memory for packed KV or QKV.
(4) Add unit tests for packed kv and packed qkv format in
MultiHeadAttention
(5) Mark some fusion options for SD only

Performance tests show slight improvement in T4. Average latency reduced
0.15 seconds (from 5.25s to 5.10s) for 512x512 in 50 steps for SD 1.5
models. Memory usage drops from 5.1GB to 4.8GB.
  • Loading branch information
tianleiwu committed Feb 15, 2023
1 parent 7529f14 commit 2505532
Show file tree
Hide file tree
Showing 19 changed files with 1,121 additions and 211 deletions.
6 changes: 3 additions & 3 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2310,12 +2310,12 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
</dl>

#### Inputs (2 - 6)
#### Inputs (1 - 6)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
Expand Down
86 changes: 56 additions & 30 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,53 +22,79 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// When packed kv is used:
// query (Q) : (B, S, D)
// key (K) : (B, L, N, 2, H)
// value (V) : None
// bias (Q/K/V) : None
// When packed qkv is used:
// query (Q) : (B, L, N, 3, H)
// key (K) : None
// value (V) : None
// bias (Q/K/V) : None

const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
if (query_dims.size() != 3 && query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
query_dims.size());
}

const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int hidden_size = static_cast<int>(query_dims[2]);
int hidden_size = query_dims.size() == 3 ? static_cast<int>(query_dims[2]) : (num_heads * static_cast<int>(query_dims[4]));
int head_size = static_cast<int>(hidden_size) / num_heads;
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_sequence_length = sequence_length;

if (key != nullptr) {
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
query_dims.size());
}

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
}
}

kv_sequence_length = static_cast<int>(key_dims[1]);
} else { // packed QKV
if (query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ",
query_dims.size());
}
if (static_cast<int>(query_dims[2]) != num_heads || static_cast<int>(query_dims[3]) != 3) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
"Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv");
}
}

Expand All @@ -82,17 +108,17 @@ Status CheckInputs(const T* query,
// Currently, bias is not allowed for packed KV. This constraint can be removed later.
// Here we assume that fusion tool will not include bias for packed KV.
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. ");
}
}

AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
const auto& mask_dims = key_padding_mask->Shape().GetDims();
if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) {
if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(batch_size)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
} else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) {
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
}

Expand All @@ -115,7 +141,7 @@ Status CheckInputs(const T* query,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (key_dims[1] != value_dims[1]) {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
}
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
Expand All @@ -190,6 +191,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner,
use_fused_cross_attention,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());

Expand All @@ -204,12 +206,15 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = nullptr;
data.cumulated_sequence_length_kv_cache = nullptr;

return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
}
Expand Down
Loading

0 comments on commit 2505532

Please sign in to comment.