Skip to content

Commit

Permalink
[Typing][C-74] Add type annotations for `python/paddle/incubate/nn/fu…
Browse files Browse the repository at this point in the history
…nctional/fused_matmul_bias.py` (#66656)
  • Loading branch information
Betelgeu authored Jul 28, 2024
1 parent c8710f0 commit 03e5e1a
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions python/paddle/incubate/nn/functional/fused_matmul_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from paddle import _C_ops, _legacy_C_ops
from paddle.base.layer_helper import LayerHelper
from paddle.framework import (
Expand All @@ -20,10 +24,18 @@
)
from paddle.tensor.linalg import matmul

if TYPE_CHECKING:
from paddle import Tensor


def fused_matmul_bias(
x, y, bias=None, transpose_x=False, transpose_y=False, name=None
):
x: Tensor,
y: Tensor,
bias: Tensor | None = None,
transpose_x: bool = False,
transpose_y: bool = False,
name: str | None = None,
) -> Tensor:
"""
Applies matrix multiplication of two tensors and then bias addition if provided.
This method requires CUDA version >= 11.6.
Expand Down Expand Up @@ -80,7 +92,13 @@ def fused_matmul_bias(
return out


def fused_linear(x, weight, bias=None, transpose_weight=False, name=None):
def fused_linear(
x: Tensor,
weight: Tensor,
bias: Tensor | None = None,
transpose_weight: bool = False,
name: str | None = None,
) -> Tensor:
"""
Fully-connected linear transformation operator. This method requires CUDA version >= 11.6.
Expand Down Expand Up @@ -116,8 +134,13 @@ def fused_linear(x, weight, bias=None, transpose_weight=False, name=None):


def fused_linear_activation(
x, y, bias, trans_x=False, trans_y=False, activation=None
):
x: Tensor,
y: Tensor,
bias: Tensor,
trans_x: bool = False,
trans_y: bool = False,
activation: Literal['gelu', 'relu'] | None = None,
) -> Tensor:
"""
Fully-connected linear and activation transformation operator. This method requires CUDA version >= 11.6.
Expand Down

0 comments on commit 03e5e1a

Please sign in to comment.