scaled_dot_product_attention api #55242
@@ -255,4 +255,5 @@
    'multi_margin_loss',
    'soft_margin_loss',
    'gaussian_nll_loss',
    'scaled_dot_product_attention',
]
@@ -407,4 +407,55 @@ def flash_attn_unpadded(
    return out, softmax if return_softmax else None


scaled_dot_product_attention = flash_attention
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
    r"""
    The equation is:

    .. math::

        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V

    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
    The dimensions of the three parameters are the same.
    ``d`` represents the size of the last dimension of the three parameters.
Review thread:
- Here,
- In a mathematical formula, one generally uses "where".
- OK, this change of mine was made with ChatGPT, just for reference.
- 👌
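As a sketch of what this suggestion could look like in the docstring (the exact wording that was merged is not shown in this thread), the math block could spell the symbols out explicitly:

    .. math::

        \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V

    where :math:`Q`, :math:`K`, and :math:`V` are the query, key, and value tensors,
    and :math:`d` is the size of their last dimension.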
    Warning:
        This API only supports inputs with dtype float16 and bfloat16.

    Args:
        query(Tensor): The query tensor in the Attention module.
            4-D tensor with shape:
            [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        key(Tensor): The key tensor in the Attention module.
            4-D tensor with shape:
            [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        value(Tensor): The value tensor in the Attention module.
            4-D tensor with shape:
            [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        attn_mask(Tensor, optional): A float mask of the same type as query,
            key, and value that is added to the attention score.
            Not supported yet.
        dropout_p(float): The dropout ratio.
        is_causal(bool): Whether to enable causal mode.

    Returns:
        out(Tensor): The attention tensor.
            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.

    Examples:
        .. code-block:: python
Review thread:
- The framework is introducing xdoctest; while you are at it, the example code could be changed to the xdoctest-supported format, see #55295.
- What does the xdoctest-supported format look like?
- Please see the changes in the PR I linked.
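For reference, an xdoctest-style example puts every statement behind a ">>> " prompt, shows its expected output, and marks examples that cannot run everywhere with a directive. This is only a sketch of that style; the specific directive (a block-level "# doctest: +SKIP", chosen here because the call needs flash-attention hardware support) is an assumption, not necessarily what #55295 prescribes:

    >>> # doctest: +SKIP
    >>> import paddle
    >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
    >>> out = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
    >>> print(out.shape)
    [1, 128, 2, 16]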

            >>> import paddle
            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
            >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
            >>> print(output)
    """
    assert attn_mask is None, "attn_mask is not supported yet"
Review thread:
- If
- Work to support attn_mask is already in progress, and it depends on this PR being merged.
- OK
    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
    return out
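The new API simply forwards to the fused flash_attention kernel. For readers who want the docstring equation spelled out, here is a rough, unfused sketch in plain Paddle ops. It is illustration only, not the code in this PR: the helper name is made up, numerics and dropout handling differ from the fused kernel, and attn_mask is omitted because the API does not support it yet.

    import paddle
    import paddle.nn.functional as F

    def naive_scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False):
        # Inputs are [batch_size, seq_len, num_heads, head_dim]; move heads next to
        # the batch dimension so the matmuls run per head.
        q = query.transpose([0, 2, 1, 3])
        k = key.transpose([0, 2, 1, 3])
        v = value.transpose([0, 2, 1, 3])
        d = q.shape[-1]
        # softmax(Q * K^T / sqrt(d)) * V, as in the docstring equation.
        scores = paddle.matmul(q, k, transpose_y=True) / (d ** 0.5)
        if is_causal:
            seq_len = q.shape[-2]
            causal = paddle.triu(
                paddle.full([seq_len, seq_len], float("-inf"), scores.dtype), diagonal=1
            )
            scores = scores + causal
        probs = F.softmax(scores, axis=-1)
        if dropout_p > 0.0:
            probs = F.dropout(probs, p=dropout_p)
        out = paddle.matmul(probs, v)
        # Back to [batch_size, seq_len, num_heads, head_dim].
        return out.transpose([0, 2, 1, 3])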
Review comment:
- In order to be consistent with other APIs, there must be a `name=None` parameter at the end.
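Picking up this point, the signature would then end with a trailing `name` argument; a minimal sketch of that shape (how, or whether, the name is forwarded to the underlying op is not shown in this diff):

    def scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, name=None
    ):
        # `name` is accepted only for consistency with other paddle.nn.functional
        # APIs; leaving it as None does not change the computation here.
        assert attn_mask is None, "attn_mask is not supported yet"
        out, _ = flash_attention(query, key, value, dropout_p, is_causal)
        return out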