diff --git a/python/paddle/distributed/communication/reduce.py b/python/paddle/distributed/communication/reduce.py
index 265f84901c5a53..e27e2f83e37b18 100644
--- a/python/paddle/distributed/communication/reduce.py
+++ b/python/paddle/distributed/communication/reduce.py
@@ -12,10 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar, Literal
+
 import paddle
 from paddle import framework
 from paddle.distributed.communication import stream
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+    from paddle import Tensor
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+
+    _ReduceOp: TypeAlias = Literal[0, 1, 2, 3, 4]
+
 
 class ReduceOp:
     """
@@ -48,11 +61,11 @@ class ReduceOp:
             >>> # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
     """
 
-    SUM = 0
-    MAX = 1
-    MIN = 2
-    PROD = 3
-    AVG = 4
+    SUM: ClassVar[Literal[0]] = 0
+    MAX: ClassVar[Literal[1]] = 1
+    MIN: ClassVar[Literal[2]] = 2
+    PROD: ClassVar[Literal[3]] = 3
+    AVG: ClassVar[Literal[4]] = 4
 
 
 def _get_reduce_op(reduce_op, func_name):
@@ -86,7 +99,13 @@ def _to_inplace_op(op_name):
     return f"{op_name}_"
 
 
-def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
+def reduce(
+    tensor: Tensor,
+    dst: int,
+    op: _ReduceOp = ReduceOp.SUM,
+    group: Group | None = None,
+    sync_op: bool = True,
+) -> task:
     """
 
     Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented
@@ -103,7 +122,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
             should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         dst (int): The destination rank id.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD|ReduceOp.AVG, optional): The operation used. Default value is ReduceOp.SUM.
-        group (Group, optional): The group instance return by new_group or None for global default group.
+        group (Group|None, optional): The group instance returned by new_group or None for the global default group.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
 
     Returns:
@@ -207,7 +226,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
            raise ValueError(f"Unknown parameter: {op}.")
 
 
-def is_avg_reduce_op_supported():
+def is_avg_reduce_op_supported() -> bool:
     if paddle.is_compiled_with_cuda():
         return paddle.base.core.nccl_version() >= 21000
     else:
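
For reference, a minimal sketch of how the newly annotated reduce signature is exercised. This assumes a 2-GPU job launched via python -m paddle.distributed.launch; the tensor values mirror the docstring example visible in the first hunk above.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Each rank contributes its own tensor; after the call, only dst holds the result.
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# op accepts the Literal[0..4] values aliased as _ReduceOp; ReduceOp.SUM is Literal[0].
dist.reduce(data, dst=0, op=dist.ReduceOp.SUM)
print(data)
# rank 0: [[5, 7, 9], [5, 7, 9]] (2 GPUs)

With these hints in place, a static checker can reject calls such as dist.reduce(data, dst="0") or an out-of-range op value, and the task return annotation documents that the call yields a handle that can be waited on when sync_op=False (hedged here from the usual stream-communication contract rather than anything shown in this diff).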