scaled_dot_product_attention api #55242

Merged
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -255,4 +255,5 @@
'multi_margin_loss',
'soft_margin_loss',
'gaussian_nll_loss',
'scaled_dot_product_attention',
]
53 changes: 52 additions & 1 deletion python/paddle/nn/functional/flash_attention.py
@@ -407,4 +407,55 @@ def flash_attn_unpadded(
return out, softmax if return_softmax else None


scaled_dot_product_attention = flash_attention
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
Contributor: To be consistent with other APIs, there must be a name=None parameter at the end.
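For illustration only, a sketch of what the reviewer is asking for: the same signature with a trailing name=None keyword, following the convention of other paddle.nn.functional APIs (hypothetical; not the code in this commit):

def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, name=None
):
    ...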

):
r"""
The equation is:

.. math::

result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V

where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module,
all with the same dimensions, and ``d`` represents the size of the last dimension of the three parameters.
Contributor: Here, Q, K, and V denote the three input parameters of the attention module, all sharing identical dimensions. d represents the size of the last dimension of these three parameters.

Contributor Author: In mathematical formulas, "where" is generally used.

Contributor: OK, my version of the change was generated with ChatGPT, for reference only.

Contributor Author: 👌


Warning:
This API only supports inputs with dtype float16 and bfloat16.

Args:
query(Tensor): The query tensor in the attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
key(Tensor): The key tensor in the attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
value(Tensor): The value tensor in the attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
attn_mask(Tensor, optional): A float mask of the same type as query,
key, and value that is added to the attention score.
Not supported yet.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether to enable causal mode.

Returns:
out(Tensor): The attention tensor.
4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.

Examples:
.. code-block:: python
Contributor: The framework is introducing xdoctest, so the example code could be converted to the xdoctest-supported format while you are at it, see #55295.

Contributor Author (@liuzhenhai93, Jul 20, 2023): What does the xdoctest-supported format look like? Is there a demo or a clear specification?

Contributor: Please refer to the changes in the PR I mentioned.
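As a rough illustration of the chevron-prompt style being discussed (the doctest directive shown is an assumption; see #55295 for the authoritative convention), the example below could be rewritten as:

.. code-block:: python

    >>> # doctest: +REQUIRES(env:GPU)
    >>> import paddle
    >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
    >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
    >>> print(output.shape)
    [1, 128, 2, 16]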


>>> import paddle
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
>>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
"""
assert attn_mask is None, "attn_mask is not supported yet"
Contributor: If attn_mask is not currently supported, add a TODO statement to indicate that it will be supported later.

Contributor: Work to support attn_mask is already in progress, and it depends on this PR being merged first.

Contributor: OK
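For illustration, the TODO the reviewer is asking for could simply sit above the assertion (the owner tag is a placeholder):

# TODO(liuzhenhai93): support attn_mask once the follow-up work lands.
assert attn_mask is None, "attn_mask is not supported yet"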

out, _ = flash_attention(query, key, value, dropout_p, is_causal)
return out
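For readers who want to sanity-check the formula in the docstring, a minimal reference implementation in plain Paddle ops might look like the sketch below (not part of this PR; the helper name is made up, dropout is omitted, and it gives up the memory savings of flash attention for clarity):

import paddle
import paddle.nn.functional as F


def naive_scaled_dot_product_attention(query, key, value, is_causal=False):
    # query/key/value: [batch_size, seq_len, num_heads, head_dim]
    head_dim = query.shape[-1]
    # move heads next to batch so matmul contracts over head_dim
    q = paddle.transpose(query, [0, 2, 1, 3])
    k = paddle.transpose(key, [0, 2, 1, 3])
    v = paddle.transpose(value, [0, 2, 1, 3])
    # softmax(Q * K^T / sqrt(d))
    scores = paddle.matmul(q, k, transpose_y=True) / (head_dim**0.5)
    if is_causal:
        seq_len = q.shape[2]
        causal_mask = paddle.triu(
            paddle.full([seq_len, seq_len], float("-inf"), q.dtype), diagonal=1
        )
        scores = scores + causal_mask
    weights = F.softmax(scores, axis=-1)
    out = paddle.matmul(weights, v)
    # back to [batch_size, seq_len, num_heads, head_dim]
    return paddle.transpose(out, [0, 2, 1, 3])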
51 changes: 40 additions & 11 deletions test/legacy_test/test_flash_attention.py
@@ -25,6 +25,7 @@
from paddle.nn.functional.flash_attention import (
flash_attention,
flash_attn_unpadded,
scaled_dot_product_attention,
)


@@ -85,6 +86,7 @@ def setUp(self):
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
self.use_sdp_api = False

def test_unpadded(self):
print(
@@ -212,9 +214,15 @@ def test_all(self):
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
if self.use_sdp_api:
out = scaled_dot_product_attention(
q, k, v, None, self.dropout, self.causal
)
else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)

else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
@@ -253,14 +261,19 @@ def test_all(self):
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
if self.use_sdp_api:
outs = scaled_dot_product_attention(
qs, ks, vs, None, self.dropout, self.causal
)
else:
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
else:
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
@@ -334,6 +347,22 @@ def setUp(self):
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.use_sdp_api = False
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False


class TestSDPAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.use_sdp_api = True
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False
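To close the loop, a quick standalone sanity check (not part of the PR's test suite; tolerances are rough guesses for float16) could compare the new API against the naive reference sketched in the flash_attention.py section above:

import numpy as np

import paddle
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

paddle.device.set_device("gpu")  # flash attention requires a CUDA device
q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
out = scaled_dot_product_attention(q, q, q, None, 0.0, False)
ref = naive_scaled_dot_product_attention(q, q, q, is_causal=False)  # reference sketch from above
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=5e-3, atol=1e-3)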