diff --git a/python/paddle/distributed/communication/all_reduce.py b/python/paddle/distributed/communication/all_reduce.py
index bef362a43cb7c9..6cb89eb1513f4d 100644
--- a/python/paddle/distributed/communication/all_reduce.py
+++ b/python/paddle/distributed/communication/all_reduce.py
@@ -11,13 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
 import paddle
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.reduce import ReduceOp
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+
 
-def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
+def all_reduce(
+    tensor: Tensor,
+    op: ReduceOp = ReduceOp.SUM,
+    group: Group | None = None,
+    sync_op: bool = True,
+) -> task:
     """
 
     Reduce a tensor over all ranks so that all get the result.
@@ -34,7 +47,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
         tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
             should be float16, float32, float64, int32, int64, int8, uint8 or bool.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD|ReduceOp.AVG, optional): The operation used. Default value is ReduceOp.SUM.
-        group (Group, optional): The group instance return by new_group or None for global default group.
+        group (Group|None, optional): The group instance return by new_group or None for global default group.
         sync_op (bool, optional): Wether this op is a sync op. Default value is True.
 
     Returns:
diff --git a/python/paddle/distributed/communication/all_to_all.py b/python/paddle/distributed/communication/all_to_all.py
index 2d4aed8fe304c2..a7ca91142270d5 100644
--- a/python/paddle/distributed/communication/all_to_all.py
+++ b/python/paddle/distributed/communication/all_to_all.py
@@ -11,11 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
 from paddle.distributed.communication import stream
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+
 
-def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
+def alltoall(
+    in_tensor_list: list[Tensor],
+    out_tensor_list: list[Tensor],
+    group: Group | None = None,
+    sync_op: bool = True,
+) -> task:
     """
     Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
     As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
@@ -63,13 +76,13 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
 
 
 def alltoall_single(
-    in_tensor,
-    out_tensor,
-    in_split_sizes=None,
-    out_split_sizes=None,
-    group=None,
-    sync_op=True,
-):
+    in_tensor: Tensor,
+    out_tensor: Tensor,
+    in_split_sizes: list[int] | None = None,
+    out_split_sizes: list[int] | None = None,
+    group: Group | None = None,
+    sync_op: bool = True,
+) -> task:
     """
     Scatter a single input tensor to all participators and gather the received tensors in out_tensor.
 
@@ -79,11 +92,11 @@ def alltoall_single(
     Args:
         in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
-        in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
+        in_split_sizes (list[int]|None, optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
             must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
-        out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
+        out_split_sizes (list[int]|None, optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
             must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
-        group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
+        group (Group|None, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
 
     Returns:
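
For context, a minimal usage sketch of the two modules touched here, assuming a two-process run launched with `paddle.distributed.launch` (the tensor values and split sizes are illustrative, not part of this change). It exercises the signatures the new annotations describe, including the `task` handle that the return annotation refers to:

```python
# Minimal usage sketch (illustrative, not part of this diff): assumes two ranks
# started via paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# all_reduce: every rank ends up holding the elementwise SUM across ranks.
data = paddle.to_tensor([rank + 1.0, rank + 2.0])
dist.all_reduce(data, op=dist.ReduceOp.SUM)  # sync_op=True: returns after completion

# alltoall_single with explicit split sizes: rank i keeps the i-th slice from each peer.
in_tensor = paddle.arange(4, dtype='int64') + 4 * rank
out_tensor = paddle.empty([4], dtype='int64')
task = dist.alltoall_single(
    in_tensor,
    out_tensor,
    in_split_sizes=[2, 2],
    out_split_sizes=[2, 2],
    sync_op=False,
)
task.wait()  # the async handle is the `task` type used in the new return annotations
```

Keeping `Tensor`, `Group` and `task` imports under `if TYPE_CHECKING:` (together with `from __future__ import annotations`) means the annotations add no runtime import cost to these modules and sidestep any import cycles the eager imports could introduce.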