Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][A-3] Add type annotations for paddle/tensor/creation.py #65082

Merged
merged 18 commits into from
Jun 17, 2024
1 change: 0 additions & 1 deletion python/paddle/_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

# Shape
from .shape import (
DynamicShapeLike as DynamicShapeLike,
ShapeLike as ShapeLike,
Size1 as Size1,
Size2 as Size2,
Expand Down
14 changes: 7 additions & 7 deletions python/paddle/_typing/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, List, Sequence, Tuple, Union

from typing_extensions import TypeAlias

if TYPE_CHECKING:
from .. import Tensor

DynamicShapeLike: TypeAlias = Union[
Tuple[Union[int, "Tensor", None], ...],
List[Union[int, "Tensor", None]],

_DynamicShapeLike: TypeAlias = Union[
Sequence[Union[int, "Tensor", None]],
"Tensor",
]


ShapeLike: TypeAlias = Union[
Tuple[int, ...],
List[int],
_StaticShapeLike: TypeAlias = Union[
Sequence[int],
"Tensor",
]

ShapeLike: TypeAlias = Union[_DynamicShapeLike, _StaticShapeLike]

# for size parameters, eg, kernel_size, stride ...
Size1: TypeAlias = Union[int, Tuple[int], List[int]]
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/base/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from paddle import pir
from paddle._typing.dtype_like import DTypeLike

from ..pir import Value
from ..pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE, ParameterMeta
Expand Down Expand Up @@ -89,7 +90,7 @@ def convert_uint16_to_float(data):
return np.reshape(new_data, data.shape)


def convert_dtype(dtype):
def convert_dtype(dtype: DTypeLike) -> str:
if isinstance(dtype, core.VarDesc.VarType):
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
Expand Down
Loading