Skip to content

Commit

Permalink
Some kernel changes for TULR (microsoft#14517)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
1. fix a bug in relative position bias kernel where seq_len > 32
2. rename extra_add_qk to relative_position_bias
3. support relative_position_bias in multihead attention (B, N, S, S*)
4. gru_gate support by Lei


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
  • Loading branch information
3 people authored and preetha-intel committed Feb 15, 2023
1 parent db869f6 commit 0cf39b6
Show file tree
Hide file tree
Showing 38 changed files with 802 additions and 132 deletions.
4 changes: 4 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ set(contrib_ops_excluded_files
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
"bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
"bert/relative_attn_bias_impl.h"
"bert/skip_layer_norm.cc"
"bert/skip_layer_norm.h"
"bert/skip_layer_norm_impl.cu"
Expand Down
63 changes: 59 additions & 4 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Do not modify directly.*
* <a href="#com.microsoft.FusedConv">com.microsoft.FusedConv</a>
* <a href="#com.microsoft.FusedGemm">com.microsoft.FusedGemm</a>
* <a href="#com.microsoft.FusedMatMul">com.microsoft.FusedMatMul</a>
* <a href="#com.microsoft.GatedRelativePositionBias">com.microsoft.GatedRelativePositionBias</a>
* <a href="#com.microsoft.GatherND">com.microsoft.GatherND</a>
* <a href="#com.microsoft.Gelu">com.microsoft.Gelu</a>
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
Expand Down Expand Up @@ -152,7 +153,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)</dd>
<dt><tt>past</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>extra_add</tt> (optional) : T</dt>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
Expand Down Expand Up @@ -1608,6 +1609,58 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GatedRelativePositionBias"></a><a name="com.microsoft.gatedrelativepositionbias">**com.microsoft.GatedRelativePositionBias**</a>

query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
gate_u, gate_r = torch.sigmoid(
self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0
rel_pos_bias = gate_u_1 * rel_pos

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>

#### Inputs

<dl>
<dt><tt>query_layer</tt> : T</dt>
<dd>tensor with shape (batch_size, seq_len, num_heads x head_size)</dd>
<dt><tt>query_bias</tt> : T</dt>
<dd>1-d tensor with shape (num_heads x head_size)</dd>
<dt><tt>rel_pos</tt> : T</dt>
<dd>tensor with shape (1, num_head, seq_len, seq_len)</dd>
<dt><tt>weight</tt> : T</dt>
<dd>gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2</dd>
<dt><tt>bias</tt> : T</dt>
<dd>bias for the gated_ur_linear, shape (D)</dd>
<dt><tt>eco_a</tt> : T</dt>
<dd>tensor of shape (1, num_heads, 1, 1)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>output tensor with shape (batch_size, num_heads, seq_len, seq_len)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


### <a name="com.microsoft.GatherND"></a><a name="com.microsoft.gathernd">**com.microsoft.GatherND**</a>

Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather
Expand Down Expand Up @@ -2222,7 +2275,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
</dl>

#### Inputs (2 - 5)
#### Inputs (2 - 6)

<dl>
<dt><tt>query</tt> : T</dt>
Expand All @@ -2235,6 +2288,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
</dl>

#### Outputs
Expand Down Expand Up @@ -3221,7 +3276,7 @@ This version of the operator has been available since version 1 of the 'com.micr
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
and present state are optional. Present state could appear in output even when past state is not in input.
Current version does not support past/present, extra_add and qkv_hidden_sizes.
Current version does not support past/present, relative_position_bias and qkv_hidden_sizes.
TODO: Support them if needed in the future.

#### Version
Expand Down Expand Up @@ -3286,7 +3341,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).</dd>
<dt><tt>past</tt> (optional) : Q</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).</dd>
<dt><tt>extra_add</tt> (optional) : S</dt>
<dt><tt>relative_position_bias</tt> (optional) : S</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).</dd>
</dl>

Expand Down
11 changes: 6 additions & 5 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
Expand Down Expand Up @@ -785,7 +785,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
Expand All @@ -803,18 +803,19 @@ Do not modify directly.*
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* extra_add:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* relative_position_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLongformerAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* weight:**Q**<br> *in* scale_weight:**S**<br> *in* bias:**S**<br> *in* scale_bias:**S**<br> *in* scale_qkv_gemm:**S**<br> *in* mask:**F**<br> *in* global_weight:**Q**<br> *in* scale_global_weight:**S**<br> *in* global_bias:**S**<br> *in* scale_global_gemm:**S**<br> *in* global:**G**<br> *in* scale_output:**S**<br> *out* output:**Q**|1+|**F** = tensor(float16)<br/> **G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
Expand Down Expand Up @@ -1159,7 +1160,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {

const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const Tensor* extra_add_qk = context->Input<Tensor>(5);
const Tensor* relative_position_bias = context->Input<Tensor>(5);

const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);

Expand All @@ -208,7 +208,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past,
extra_add_qk,
relative_position_bias,
&parameters));

const int batch_size = parameters.batch_size;
Expand Down Expand Up @@ -331,7 +331,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
return ApplyAttention(Q, K, V, mask_index, past, output,
batch_size, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
extra_add_qk, context);
relative_position_bias, context);
}
} // namespace contrib
} // namespace onnxruntime
Loading

0 comments on commit 0cf39b6

Please sign in to comment.