Skip to content

Commit

Permalink
[Typing][C-13,C-14] Add type annotations for `python/paddle/distribut…
Browse files Browse the repository at this point in the history
…ed/communication/{all_reduce,all_to_all}.py` (PaddlePaddle#66505)
  • Loading branch information
enkilee authored and inaomIIsfarell committed Jul 31, 2024
1 parent 4fc6d0a commit 986dc11
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
17 changes: 15 additions & 2 deletions python/paddle/distributed/communication/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
from paddle.distributed.communication import stream
from paddle.distributed.communication.reduce import ReduceOp

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
def all_reduce(
tensor: Tensor,
op: ReduceOp = ReduceOp.SUM,
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Reduce a tensor over all ranks so that all get the result.
Expand All @@ -34,7 +47,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD|ReduceOp.AVG, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global default group.
group (Group|None, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Wether this op is a sync op. Default value is True.
Returns:
Expand Down
35 changes: 24 additions & 11 deletions python/paddle/distributed/communication/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from paddle.distributed.communication import stream

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import task
from paddle.distributed.communication.group import Group


def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
def alltoall(
in_tensor_list: list[Tensor],
out_tensor_list: list[Tensor],
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list.
As shown below, the in_tensor_list in GPU0 includes 0_0 and 0_1, and GPU1 includes 1_0 and 1_1.
Expand Down Expand Up @@ -63,13 +76,13 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):


def alltoall_single(
in_tensor,
out_tensor,
in_split_sizes=None,
out_split_sizes=None,
group=None,
sync_op=True,
):
in_tensor: Tensor,
out_tensor: Tensor,
in_split_sizes: list[int] | None = None,
out_split_sizes: list[int] | None = None,
group: Group | None = None,
sync_op: bool = True,
) -> task:
"""
Scatter a single input tensor to all participators and gather the received tensors in out_tensor.
Expand All @@ -79,11 +92,11 @@ def alltoall_single(
Args:
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
in_split_sizes (list[int]|None, optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
out_split_sizes (list[int]|None, optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None.
group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
group (Group|None, optional): The group instance return by ``new_group`` or None for global default group. Default: None.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
Expand Down

0 comments on commit 986dc11

Please sign in to comment.