[Typing][C-78] Add type annotations for python/paddle/incubate/nn/functional/masked_multihead_attention.py #67558

Merged: 6 commits, Aug 22, 2024
@@ -12,9 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, overload

from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor


@overload
def masked_multihead_attention(
x: Tensor,
cache_kv: Tensor | None = ...,
bias: Tensor | None = ...,
src_mask: Tensor | None = ...,
cum_offsets: Tensor | None = ...,
sequence_lengths: Tensor | None = ...,
rotary_tensor: Tensor | None = ...,
beam_cache_offset: None = ...,
qkv_out_scale: Tensor | None = ...,
out_shift: Tensor | None = ...,
out_smooth: Tensor | None = ...,
seq_len: int = ...,
rotary_emb_dims: int = ...,
use_neox_rotary_style: bool = ...,
compute_dtype: str = ...,
out_scale: float = ...,
quant_round_type: int = ...,
quant_max_bound: float = ...,
quant_min_bound: float = ...,
) -> tuple[Tensor, Tensor]: ...


@overload
def masked_multihead_attention(
x: Tensor,
cache_kv: Tensor | None = ...,
bias: Tensor | None = ...,
src_mask: Tensor | None = ...,
cum_offsets: Tensor | None = ...,
sequence_lengths: Tensor | None = ...,
rotary_tensor: Tensor | None = ...,
beam_cache_offset: Tensor = ...,
qkv_out_scale: Tensor | None = ...,
out_shift: Tensor | None = ...,
out_smooth: Tensor | None = ...,
seq_len: int = ...,
rotary_emb_dims: int = ...,
use_neox_rotary_style: bool = ...,
compute_dtype: str = ...,
out_scale: float = ...,
quant_round_type: int = ...,
quant_max_bound: float = ...,
quant_min_bound: float = ...,
) -> tuple[Tensor, Tensor, Tensor]: ...


def masked_multihead_attention(
x,
@@ -44,7 +99,7 @@ def masked_multihead_attention(

Args:
        x (Tensor): The input tensor could be a 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
cache_kv (Tensor): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1].
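Taken together, the two overloads let a static type checker infer the return arity from `beam_cache_offset`: when it is `None` (or omitted) the call resolves to `tuple[Tensor, Tensor]`, and when a `Tensor` is passed it resolves to `tuple[Tensor, Tensor, Tensor]`. The sketch below illustrates that pattern in isolation; the `Tensor` class and the `mmha` helper are placeholders for this example only, not Paddle's runtime implementation.

# Illustrative sketch only (not part of this PR): how the overloads above
# resolve for a type checker. "Tensor" and "mmha" are placeholders.
from __future__ import annotations

from typing import overload


class Tensor:  # stand-in for paddle.Tensor, for demonstration only
    ...


@overload
def mmha(x: Tensor, beam_cache_offset: None = ...) -> tuple[Tensor, Tensor]: ...

@overload
def mmha(x: Tensor, beam_cache_offset: Tensor = ...) -> tuple[Tensor, Tensor, Tensor]: ...


def mmha(x, beam_cache_offset=None):
    # The runtime mirrors the annotated contract: a third output is
    # returned only when a beam_cache_offset tensor is supplied.
    out, cache_kv_out = Tensor(), Tensor()
    if beam_cache_offset is None:
        return out, cache_kv_out
    return out, cache_kv_out, Tensor()


two_outputs = mmha(Tensor())                                # inferred: tuple[Tensor, Tensor]
three_outputs = mmha(Tensor(), beam_cache_offset=Tensor())  # inferred: tuple[Tensor, Tensor, Tensor]

Running `reveal_type` in mypy or pyright on the two calls above shows the matching tuple types, which is the behavior the annotations added in this PR aim to provide for `masked_multihead_attention`.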