diff --git a/python/paddle/base/data_feeder.py b/python/paddle/base/data_feeder.py index c914c792022873..248120a637af27 100644 --- a/python/paddle/base/data_feeder.py +++ b/python/paddle/base/data_feeder.py @@ -17,6 +17,7 @@ import numpy as np from paddle import pir +from paddle._typing.dtype_like import DTypeLike from ..pir import Value from ..pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE, ParameterMeta @@ -89,7 +90,7 @@ def convert_uint16_to_float(data): return np.reshape(new_data, data.shape) -def convert_dtype(dtype): +def convert_dtype(dtype: DTypeLike) -> str: if isinstance(dtype, core.VarDesc.VarType): if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE: return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 6ef53a757ee23f..a6cff759816687 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -16,15 +16,19 @@ import math import re +from typing import Any, Sequence, overload import numpy as np +import numpy.typing as npt import paddle from paddle import _C_ops from paddle._typing import ( DTypeLike, NestedNumbericSequence, + Numberic, PlaceLike, + ShapeLike, TensorLike, ) from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only @@ -54,7 +58,7 @@ __all__ = [] -def _complex_to_real_dtype(dtype): +def _complex_to_real_dtype(dtype: DTypeLike) -> DTypeLike: if dtype == core.VarDesc.VarType.COMPLEX64: return core.VarDesc.VarType.FP32 elif dtype == core.VarDesc.VarType.COMPLEX128: @@ -67,7 +71,7 @@ def _complex_to_real_dtype(dtype): return dtype -def _real_to_complex_dtype(dtype): +def _real_to_complex_dtype(dtype: DTypeLike) -> DTypeLike: if dtype == core.VarDesc.VarType.FP32: return core.VarDesc.VarType.COMPLEX64 elif dtype == core.VarDesc.VarType.FP64: @@ -81,8 +85,13 @@ def _real_to_complex_dtype(dtype): def create_global_var( - shape, value, dtype, persistable=False, force_cpu=False, name=None -): + shape: ShapeLike, + value: float, + dtype: DTypeLike, + persistable: bool = False, + force_cpu: bool = False, + name: str | None = None, +) -> paddle.Tensor: """ This function creates a new tensor variable with value in the global block(block 0). @@ -95,7 +104,7 @@ def create_global_var( Default: False force_cpu (bool, optional): Force this variable to be on CPU. Default: False - name (str, optional): For detailed information, please refer to + name (str|None, optional): For detailed information, please refer to :ref:`api_guide_Name` . Usually name is no need to set and None by default. Returns: @@ -162,8 +171,13 @@ def create_global_var( def create_parameter( - shape, dtype, name=None, attr=None, is_bias=False, default_initializer=None -): + shape: ShapeLike, + dtype: DTypeLike, + name: str | None = None, + attr: ParamAttr | None = None, + is_bias: bool = False, + default_initializer: paddle.nn.initializer.Initializer | None = None, +) -> paddle.Tensor: """ This function creates a parameter. The parameter is a learnable variable, which can have gradient, and can be optimized. @@ -174,15 +188,15 @@ def create_parameter( Args: shape (list of int): Shape of the parameter dtype (str): Data type of the parameter. It can be set as 'float16', 'float32', 'float64'. - name (str, optional): For detailed information, please refer to + name(str|None, optional): For detailed information, please refer to :ref:`api_guide_Name` . Usually name is no need to set and None by default. - attr (ParamAttr, optional): Attribute object of the specified argument. For detailed information, please refer to + attr (ParamAttr|None, optional): Attribute object of the specified argument. For detailed information, please refer to :ref:`api_paddle_ParamAttr` None by default, which means that ParamAttr will be initialized as it is. is_bias (bool, optional): This can affect which default initializer is chosen when default_initializer is None. If is_bias, initializer.Constant(0.0) will be used. Otherwise, Xavier() will be used. - default_initializer (Initializer, optional): Initializer for the parameter + default_initializer (Initializer|None, optional): Initializer for the parameter Returns: The created parameter. @@ -243,7 +257,9 @@ def create_parameter( ) -def create_tensor(dtype, name=None, persistable=False): +def create_tensor( + dtype: DTypeLike, name: str | None = None, persistable: bool = False +) -> paddle.Tensor: """ Create a variable, which will hold a Tensor with data type dtype. @@ -285,7 +301,13 @@ def create_tensor(dtype, name=None, persistable=False): ) -def linspace(start, stop, num, dtype=None, name=None): +def linspace( + start: float | paddle.Tensor, + stop: float | paddle.Tensor, + num: int | paddle.Tensor, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: r""" Return fixed number of evenly spaced values within a given interval. Note: no gradient calculation is performed. @@ -296,9 +318,9 @@ def linspace(start, stop, num, dtype=None, name=None): or a 0-D Tensor with data type int32, int64, float32 or float64. num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int, \ or a 0-D Tensor with data type int32. - dtype(np.dtype|str, optional): The data type of output tensor, it could be + dtype(str|paddle.dtype|np.dtype|None, optional): The data type of output tensor, it could be int32, int64, float32 and float64. Default: if None, the data type is float32. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: the output data type will be float32, float64. The 1-D tensor with fixed number of evenly spaced values, \ @@ -405,7 +427,14 @@ def linspace(start, stop, num, dtype=None, name=None): return out -def logspace(start, stop, num, base=10.0, dtype=None, name=None): +def logspace( + start: float | paddle.Tensor, + stop: float | paddle.Tensor, + num: int | paddle.Tensor, + base: float | paddle.Tensor = 10.0, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: r""" Return fixed number of logarithmical-evenly spaced values within the interval \ :math:`[base^{start}, base^{stop}]`. @@ -427,7 +456,7 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): float32 or float64. dtype(np.dtype|str, optional): The data type of output tensor, it could be \ int32, int64, float32 or float64. Default: if None, the data type is float32. \ - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: The output data type will be float32, float64. The 1-D tensor with \ @@ -558,14 +587,23 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): return out -def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True): - def _handle_tensor_dtype(tensor, dtype): +def _to_tensor_non_static( + data: TensorLike, + dtype: DTypeLike | None = None, + place: PlaceLike | None = None, + stop_gradient: bool = True, +) -> paddle.Tensor: + def _handle_tensor_dtype( + tensor: paddle.Tensor, dtype: DTypeLike + ) -> paddle.Tensor: if dtype: if convert_dtype(dtype) != convert_dtype(tensor.dtype): return tensor.astype(convert_dtype(dtype)) return tensor - def _handle_np_dtype(ndarray, dtype): + def _handle_np_dtype( + ndarray: npt.NDArray[Any], dtype: DTypeLike + ) -> npt.NDArray[Any]: if dtype: if convert_dtype(dtype) != convert_dtype(ndarray.dtype): # should not ndarray.astype('uint16') directly, data bits is wrong @@ -658,7 +696,11 @@ def _handle_np_dtype(ndarray, dtype): ) -def _to_tensor_static(data, dtype=None, stop_gradient=None): +def _to_tensor_static( + data: TensorLike, + dtype: DTypeLike | None = None, + stop_gradient: bool = True, +) -> paddle.Tensor: if isinstance(data, (Variable, paddle.pir.Value)): output = data if dtype is not None and dtype != data.dtype: @@ -813,7 +855,12 @@ def to_tensor( return _to_tensor_static(data, dtype, stop_gradient) -def full_like(x, fill_value, dtype=None, name=None): +def full_like( + x: paddle.Tensor, + fill_value: bool | float, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: """ This function creates a tensor filled with ``fill_value`` which has identical shape of ``x`` and ``dtype``. @@ -825,7 +872,7 @@ def full_like(x, fill_value, dtype=None, name=None): dtype(np.dtype|str, optional): The data type of output. The data type can be one of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``x``, ``fill_value`` and ``dtype``. @@ -896,7 +943,14 @@ def full_like(x, fill_value, dtype=None, name=None): return out -def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): +def fill_constant( + shape: ShapeLike, + dtype: DTypeLike, + value: float | paddle.Tensor, + force_cpu: bool = False, + out: paddle.Tensor | None = None, + name: str | None = None, +) -> paddle.Tensor: if in_dynamic_or_pir_mode(): place = _current_expected_place() if force_cpu: @@ -926,11 +980,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): out = _C_ops.full(shape, value, dtype, place) out.stop_gradient = True return out - - if out is not None: - _C_ops.full_(out, shape, value, dtype, place) - out.stop_gradient = True - return out + _C_ops.full_(out, shape, value, dtype, place) + out.stop_gradient = True + return out else: attrs = {'force_cpu': force_cpu} @@ -996,7 +1048,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): return out -def ones(shape, dtype=None, name=None): +def ones( + shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None +) -> paddle.Tensor: """ Create a Tensor of specified :attr:`shape` and :attr:`dtype` and fill it with 1. @@ -1006,7 +1060,7 @@ def ones(shape, dtype=None, name=None): If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. dtype (np.dtype|str, optional): Data type of output Tensor, it should be one of bool, float16, float32, float64, int32 and int64. If it is set to None, the data type will be float32. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements are 1. @@ -1044,7 +1098,9 @@ def ones(shape, dtype=None, name=None): return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name) -def ones_like(x, dtype=None, name=None): +def ones_like( + x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None +) -> paddle.Tensor: """ Returns a Tensor filled with the value 1, with the same shape and data type (use ``dtype`` if ``dtype`` is not None) as ``x``. @@ -1056,7 +1112,7 @@ def ones_like(x, dtype=None, name=None): output tensor. Supported data types: bool, float16, float32, float64, int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor filled with the value 1, with the same shape and @@ -1079,7 +1135,11 @@ def ones_like(x, dtype=None, name=None): return full_like(x=x, fill_value=1, dtype=dtype, name=name) -def zeros(shape, dtype=None, name=None): +def zeros( + shape: ShapeLike, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: """ Creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0. @@ -1089,7 +1149,7 @@ def zeros(shape, dtype=None, name=None): If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. dtype(np.dtype|str, optional): Data type of output Tensor, it supports bool, float16, float32, float64, int32 and int64. Default: if None, the data type is float32. - name(str, optional): The default value is None. Normally there is no need for user to set this + name(str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -1128,7 +1188,9 @@ def zeros(shape, dtype=None, name=None): return fill_constant(value=0.0, shape=shape, dtype=dtype, name=name) -def zeros_like(x, dtype=None, name=None): +def zeros_like( + x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None +) -> paddle.Tensor: """ Returns a Tensor filled with the value 0, with the same shape and data type (use ``dtype`` if ``dtype`` is not None) as ``x``. @@ -1140,7 +1202,7 @@ def zeros_like(x, dtype=None, name=None): output tensor. Supported data types: bool, float16, float32, float64, int32, int64. If ``dtype`` is None, the data type is the same as ``x``. Default is None. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A Tensor filled with the value 0, with the same shape and @@ -1164,19 +1226,24 @@ def zeros_like(x, dtype=None, name=None): return full_like(x=x, fill_value=0, dtype=dtype, name=name) -def eye(num_rows, num_columns=None, dtype=None, name=None): +def eye( + num_rows: int, + num_columns: int | None = None, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: """ This function constructs 2-D Tensor with ones on the diagonal and zeros elsewhere. Args: num_rows(int): the number of rows in each batch Tensor. - num_columns(int, optional): the number of columns in each batch Tensor. + num_columns(int|None, optional): the number of columns in each batch Tensor. If None, default: num_rows. dtype(np.dtype|str, optional): The data type of the returned Tensor. It should be int32, int64, float16, float32, float64, complex64, complex128. Default: if None, the data type is float32. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: An identity Tensor or LoDTensor of shape [num_rows, num_columns]. @@ -1252,7 +1319,12 @@ def _check_attr(attr, message): return out -def full(shape, fill_value, dtype=None, name=None): +def full( + shape: ShapeLike, + fill_value: bool | float | paddle.Tensor, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: """ Return a Tensor with the ``fill_value`` which size is same as ``shape``. @@ -1266,7 +1338,7 @@ def full(shape, fill_value, dtype=None, name=None): dtype(np.dtype|str, optional): Data type of the output Tensor which can be float16, float32, float64, int32, int64, if dtype is `None`, the data type of created Tensor is `float32`. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``shape``, ``fill_value`` and ``dtype``. @@ -1314,7 +1386,13 @@ def full(shape, fill_value, dtype=None, name=None): return fill_constant(shape=shape, dtype=dtype, value=fill_value, name=name) -def arange(start=0, end=None, step=1, dtype=None, name=None): +def arange( + start: float | paddle.Tensor = 0, + end: float | paddle.Tensor | None = None, + step: float | paddle.Tensor = 1, + dtype: DTypeLike | None = None, + name: str | None = None, +) -> paddle.Tensor: """ Returns a 1-D Tensor with spaced values within a given interval. @@ -1341,7 +1419,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): dtype(str|np.dtype, optional): The data type of the output tensor. Supported data types: int32, int64, float32, float64. If ``dtype`` is None, the data type is float32. Default is None. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: A 1-D Tensor with values from the interval [``start``, ``end``) @@ -1447,7 +1525,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): return out -def _tril_triu_op(helper): +def _tril_triu_op(helper: LayerHelper) -> paddle.Tensor: """Base op of tril_op and triu_op""" op_type = helper.layer_type x = helper.kwargs.get('x', None) @@ -1496,7 +1574,9 @@ def _tril_triu_op(helper): return out -def tril(x, diagonal=0, name=None): +def tril( + x: paddle.Tensor, diagonal: int = 0, name: str | None = None +) -> paddle.Tensor: r""" Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices :attr:`x`, the other elements of the result tensor are set @@ -1513,7 +1593,7 @@ def tril(x, diagonal=0, name=None): the main diagonal. The main diagonal are the set of indices :math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where :math:`d_{1}, d_{2}` are the dimensions of the matrix. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Results of lower triangular operation by the specified diagonal of input tensor x, @@ -1561,7 +1641,9 @@ def tril(x, diagonal=0, name=None): @inplace_apis_in_dygraph_only -def tril_(x, diagonal=0, name=None): +def tril_( + x: paddle.Tensor, diagonal: int = 0, name: str | None = None +) -> paddle.Tensor | None: r""" Inplace version of ``tril`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_tril`. @@ -1571,7 +1653,9 @@ def tril_(x, diagonal=0, name=None): return _C_ops.tril_(x, diagonal) -def triu(x, diagonal=0, name=None): +def triu( + x: paddle.Tensor, diagonal: int = 0, name: str | None = None +) -> paddle.Tensor: r""" Return the upper triangular part of a matrix (2-D tensor) or batch of matrices :attr:`x`, the other elements of the result tensor are set to 0. @@ -1588,7 +1672,7 @@ def triu(x, diagonal=0, name=None): the main diagonal. The main diagonal are the set of indices :math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where :math:`d_{1}, d_{2}` are the dimensions of the matrix. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Results of upper triangular operation by the specified diagonal of input tensor x, @@ -1638,7 +1722,9 @@ def triu(x, diagonal=0, name=None): @inplace_apis_in_dygraph_only -def triu_(x, diagonal=0, name=None): +def triu_( + x: paddle.Tensor, diagonal: int = 0, name: str | None = None +) -> paddle.Tensor | None: r""" Inplace version of ``triu`` API, the output Tensor will be inplaced with input ``x``. Please refer to :ref:`api_paddle_triu`. @@ -1648,6 +1734,18 @@ def triu_(x, diagonal=0, name=None): return _C_ops.triu_(x, diagonal) +@overload +def meshgrid( + args: Sequence[paddle.Tensor], name: str | None = None +) -> paddle.Tensor: + ... + + +@overload +def meshgrid(*args: paddle.Tensor, name: str | None = None) -> paddle.Tensor: + ... + + def meshgrid(*args, **kwargs): """ @@ -1722,7 +1820,9 @@ def meshgrid(*args, **kwargs): return out -def diag_embed(input, offset=0, dim1=-2, dim2=-1): +def diag_embed( + input: TensorLike, offset: int = 0, dim1: int = -2, dim2: int = -1 +) -> paddle.Tensor: """ Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by ``input``. By default, a 2D plane formed by the last two dimensions @@ -1841,7 +1941,9 @@ def __check_input(input, offset, dim1, dim2): return out -def diagflat(x, offset=0, name=None): +def diagflat( + x: paddle.Tensor, offset: int = 0, name: str | None = None +) -> paddle.Tensor: """ If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned. @@ -1859,7 +1961,7 @@ def diagflat(x, offset=0, name=None): Args: x (Tensor): The input tensor. It can be any shape. Its data type should be float16, float32, float64, int32, int64. offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. Default: 0 (main diagonal). - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor, a square matrix. The output data type is the same as input data type. @@ -1975,7 +2077,12 @@ def diagflat(x, offset=0, name=None): return out2 -def diag(x, offset=0, padding_value=0, name=None): +def diag( + x: paddle.Tensor, + offset: int = 0, + padding_value: int = 0, + name: str | None = None, +) -> paddle.Tensor: """ If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned. @@ -1993,7 +2100,7 @@ def diag(x, offset=0, padding_value=0, name=None): x (Tensor): The input tensor. Its shape is either 1-D or 2-D. Its data type should be float16, float32, float64, int32, int64, complex64, complex128. offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. padding_value (int|float, optional): Use this value to fill the area outside the specified diagonal band. Only takes effect when the input is a 1-D Tensor. The default value is 0. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor, a square matrix or a vector. The output data type is the same as input data type. @@ -2092,7 +2199,9 @@ def diag(x, offset=0, padding_value=0, name=None): return out -def empty(shape, dtype=None, name=None): +def empty( + shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None +) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which size is same as ``shape``. @@ -2104,7 +2213,7 @@ def empty(shape, dtype=None, name=None): which can be bool, float16, float32, float64, int32, int64, complex64, complex128 if dtype is `None`, the data type of created Tensor use global default dtype (see ``get_default_dtype`` for details). - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``shape`` and ``dtype``, and is uninitialized. @@ -2229,7 +2338,9 @@ def empty(shape, dtype=None, name=None): return out -def empty_like(x, dtype=None, name=None): +def empty_like( + x: paddle.Tensor, dtype: DTypeLike | None = None, name: str | None = None +) -> paddle.Tensor: """ Returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``. If the ``dtype`` is None, the data type of Tensor is same with ``x``. @@ -2239,7 +2350,7 @@ def empty_like(x, dtype=None, name=None): dtype(np.dtype|str, optional): The data type of output. The data type can be one of bool, float16, float32, float64, int32, int64. The default value is None, which means the output data type is the same as input. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized. @@ -2339,7 +2450,7 @@ def empty_like(x, dtype=None, name=None): return out -def assign(x, output=None): +def assign(x: TensorLike, output: paddle.Tensor | None = None) -> paddle.Tensor: """ Copy value of the :attr:`x` to the :attr:`output`. @@ -2348,7 +2459,7 @@ def assign(x, output=None): x (Tensor|np.ndarray|list|tuple|scalar): A Tensor, numpy ndarray, tuple/list of scalar, or scalar. Its data type can be float16, float32, float64, int32, int64 or bool. Note: the float64 data will be converted to float32 because of current platform protobuf data limitation. - output (Tensor, optional): A Tensor. If :attr:`output` is None, a new Tensor will be created as :attr:`output`. Default: None. + output (Tensor|None, optional): A Tensor. If :attr:`output` is None, a new Tensor will be created as :attr:`output`. Default: None. Returns: Tensor: A Tensor with the same shape, data type and value as :attr:`x`. @@ -2363,9 +2474,9 @@ def assign(x, output=None): [[2.5 2.5] [2.5 2.5] [2.5 2.5]] - >>> array = np.array([[1, 1], - ... [3, 4], - ... [1, 3]]).astype(np.int64) + >>> array = np.array([[1, 1], [3, 4], [1, 3]]).astype( + ... np.int64 + ... ) # type: ignore >>> result1 = paddle.zeros(shape=[3, 3], dtype='float32') >>> paddle.assign(array, result1) >>> print(result1.numpy()) @@ -2547,7 +2658,7 @@ def convert_scalar(x): return output -def clone(x, name=None): +def clone(x: paddle.Tensor, name: str | None = None) -> paddle.Tensor: """ Returns a copy of input Tensor. It will always have a Tensor copy. @@ -2555,7 +2666,7 @@ def clone(x, name=None): Parameters: x (Tensor): The input Tensor. - name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor, A Tensor copied from ``input``. @@ -2574,16 +2685,16 @@ def clone(x, name=None): >>> y = clone_x**3 >>> y.backward() - >>> print(clone_x.grad.numpy()) + >>> print(clone_x.grad.numpy()) # type: ignore [3. 3.] - >>> print(x.grad.numpy()) + >>> print(x.grad.numpy()) # type: ignore [3. 3.] """ return x.clone() # NOTE(zhiqiu): not public -def _memcpy(input, place=None, output=None): +def _memcpy(input, place=None, output=None) -> paddle.Tensor: """ The OP copies the :attr:`input` to the :attr:`output`. @@ -2667,16 +2778,18 @@ def _memcpy(input, place=None, output=None): return output -def complex(real, imag, name=None): +def complex( + real: paddle.Tensor, imag: paddle.Tensor, name: str | None = None +) -> paddle.Tensor: """Return a complex tensor given the real and image component. Args: real (Tensor): The real component. The data type should be 'float32' or 'float64'. imag (Tensor): The image component. The data type should be the same as ``real``. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``real`` and ``imag``. + Tensor, The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``real`` and ``imag``. Note: ``paddle.complex`` supports broadcasting. If you want know more about broadcasting, please refer to `Introduction to Tensor`_ . @@ -2719,7 +2832,9 @@ def complex(real, imag, name=None): return out -def tril_indices(row, col, offset=0, dtype='int64'): +def tril_indices( + row: int, col: int, offset: int = 0, dtype='int64' +) -> paddle.Tensor: """ Return the indices of the lower triangular part of the 2-D matrix whose row and col is known. Indices are ordered based on row and then columns. @@ -2803,7 +2918,9 @@ def tril_indices(row, col, offset=0, dtype='int64'): return out -def triu_indices(row, col=None, offset=0, dtype='int64'): +def triu_indices( + row: int, col: int | None = None, offset: int = 0, dtype='int64' +) -> paddle.Tensor: """ Return the indices of the upper triangular part of the 2-D matrix whose row and col is known. Indices are ordered based on row and then columns. @@ -2812,7 +2929,7 @@ def triu_indices(row, col=None, offset=0, dtype='int64'): Args: row (int): The input x which is a int number describe the number of row of the matrix. - col (int, optional): The input x which is a int number describe the number of col of the matrix. + col (int|None, optional): The input x which is a int number describe the number of col of the matrix. default value for col is None, then it will be set equal to row, indicting a square matrix. offset (int, optional): The offset to consider, default value is 0. @@ -2882,16 +2999,18 @@ def triu_indices(row, col=None, offset=0, dtype='int64'): return out -def polar(abs, angle, name=None): +def polar( + abs: paddle.Tensor, angle: paddle.Tensor, name: str | None = None +) -> paddle.Tensor: """Return a Cartesian coordinates corresponding to the polar coordinates complex tensor given the ``abs`` and ``angle`` component. Args: abs (Tensor): The abs component. The data type should be 'float32' or 'float64'. angle (Tensor): The angle component. The data type should be the same as ``abs``. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``abs`` and ``angle``. + Tensor, The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``abs`` and ``angle``. Note: ``paddle.polar`` supports broadcasting. If you want know more about broadcasting, please refer to `Introduction to Tensor`_ . @@ -2921,14 +3040,19 @@ def polar(abs, angle, name=None): @dygraph_only -def cauchy_(x, loc=0, scale=1, name=None): +def cauchy_( + x: paddle.Tensor, + loc: Numberic = 0, + scale: Numberic = 1, + name: str | None = None, +) -> paddle.Tensor: """Fills the tensor with numbers drawn from the Cauchy distribution. Args: x (Tensor): the tensor will be filled, The data type is float32 or float64. loc (scalar, optional): Location of the peak of the distribution. The data type is float32 or float64. scale (scalar, optional): The half-width at half-maximum (HWHM). The data type is float32 or float64. Must be positive values. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: input tensor with numbers drawn from the Cauchy distribution. @@ -2955,14 +3079,18 @@ def cauchy_(x, loc=0, scale=1, name=None): @dygraph_only -def geometric_(x, probs, name=None): +def geometric_( + x: paddle.Tensor, + probs: float | paddle.Tensor, + name: str | None = None, +) -> paddle.Tensor: """Fills the tensor with numbers drawn from the Geometric distribution. Args: x (Tensor): the tensor will be filled, The data type is float32 or float64. - probs (Real|Tensor): Probability parameter. + probs (float|Tensor): Probability parameter. The value of probs must be positive. When the parameter is a tensor, probs is probability of success for each trial. - name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor: input tensor with numbers drawn from the Geometric distribution.