diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 2542ecf864da..9b6f4a5195e5 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1775,11 +1775,20 @@ def test_send_recv(self): with self.assertRaises(ValueError): dist.send(input_tensor, dist.get_rank()) + with self.assertRaises(ValueError): + dist.send(input_tensor, group_dst=dist.get_rank()) + + with self.assertRaises(ValueError): + dist.send(input_tensor, dist.get_rank(), group_dst=dist.get_rank()) + with self.assertRaises(ValueError): + dist.send(input_tensor) # test recv input_tensor = torch.zeros(2, 2) dist.recv(input_tensor, (self.rank + 1) % self.world_size) self.assertEqual(input_tensor, torch.zeros(2, 2) + 2) + with self.assertRaises(ValueError): + dist.recv(input_tensor, src=0, group_src=0) dist.barrier() # intentionally not calling into `destroy_process_group` as not all diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index dbb520cf7136..1c9d54621cb7 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3825,8 +3825,9 @@ def test_reduce_subgroup(self): @requires_nccl() @skip_if_lt_x_gpu(4) + @parametrize("group_rank", [True, False]) @parametrize("async_op", [True, False]) - def test_send_recv_subgroup(self, async_op): + def test_send_recv_subgroup(self, async_op, group_rank): world_size = 4 if self.rank >= world_size: return @@ -3835,17 +3836,29 @@ def test_send_recv_subgroup(self, async_op): if self.rank == 0 or self.rank == 2: x = torch.empty((10,), device=device) if async_op: - c10d.irecv(x, src=self.rank + 1, group=subgroup).wait() + if group_rank: + c10d.irecv(x, group_src=1, group=subgroup).wait() + else: + c10d.irecv(x, src=self.rank + 1, group=subgroup).wait() else: - c10d.recv(x, src=self.rank + 1, group=subgroup) + if group_rank: + c10d.recv(x, group_src=1, group=subgroup) + else: + c10d.recv(x, src=self.rank + 1, group=subgroup) expected = torch.ones((10,), device=device) * (self.rank + 1) self.assertEqual(x, expected) else: x = torch.ones((10,), device=device) * self.rank if async_op: - c10d.isend(x, dst=self.rank - 1, group=subgroup).wait() + if group_rank: + c10d.isend(x, group_dst=0, group=subgroup).wait() + else: + c10d.isend(x, dst=self.rank - 1, group=subgroup).wait() else: - c10d.send(x, dst=self.rank - 1, group=subgroup) + if group_rank: + c10d.send(x, group_dst=0, group=subgroup) + else: + c10d.send(x, dst=self.rank - 1, group=subgroup) @requires_nccl() @skip_if_lt_x_gpu(4) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 8c8d7f2a8d8e..25c34073694c 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1112,6 +1112,38 @@ def _check_tensor_list(param, param_name) -> None: ) +def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup: + if group is None or group is GroupMember.WORLD: + group = _get_default_group() + return group + + +def _canonicalize_group_rank( + group: ProcessGroup, + global_rank: Optional[int] = None, + group_rank: Optional[int] = None, +) -> int: + """ + Helper method to take _either_ a global rank or a group rank and produce a group rank. + """ + if group_rank is not None: + if global_rank is not None: + raise ValueError("Can't specify both group_rank and global_rank") + else: + if global_rank is None: + raise ValueError("Must specify global_rank or group_rank") + group_rank = get_group_rank(group, global_rank) + return group_rank + + +def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str): + if group.rank() == rank: + raise ValueError( + f"Invalid {rank_type} rank: {rank_type} rank should not be the same as " + "the rank of the current process." + ) + + def _as_iterable(obj) -> collections.abc.Iterable: return obj if isinstance(obj, list) else (obj,) @@ -2217,7 +2249,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int: def isend( - tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 + tensor: torch.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, ) -> Optional[Work]: """ Send a tensor asynchronously. @@ -2229,18 +2265,23 @@ def isend( .. warning:: ``tag`` is not supported with the NCCL backend. + Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self. + Args: tensor (Tensor): Tensor to send. dst (int): Destination rank on global process group (regardless of ``group`` argument) group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` Returns: A distributed request object. None, if not part of the group """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): _warn_not_in_group("isend") @@ -2249,13 +2290,7 @@ def isend( if tensor.is_complex(): tensor = torch.view_as_real(tensor) - if group is None or group is GroupMember.WORLD: - pg = _get_default_group() - else: - pg = group - dst = get_group_rank(pg, dst) - - return pg.send([tensor], dst, tag) + return group.send([tensor], group_dst, tag) def irecv( @@ -2263,6 +2298,7 @@ def irecv( src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0, + group_src: Optional[int] = None, ) -> Optional[Work]: """ Receives a tensor asynchronously. @@ -2270,6 +2306,8 @@ def irecv( .. warning:: ``tag`` is not supported with the NCCL backend. + Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self. + Args: tensor (Tensor): Tensor to fill with received data. src (int, optional): Source rank on global process group (regardless of ``group`` argument). @@ -2277,6 +2315,7 @@ def irecv( group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match recv with remote send + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. Returns: A distributed request object. @@ -2291,24 +2330,21 @@ def irecv( if tensor.is_complex(): tensor = torch.view_as_real(tensor) - if group is None or group is GroupMember.WORLD: - pg = _get_default_group() - else: - pg = group - - if src is None: - return pg.recv_anysource([tensor], tag) + group = _group_or_default_group(group) + if src is None and group_src is None: + return group.recv_anysource([tensor], tag) else: - if pg is GroupMember.WORLD: - return pg.recv([tensor], src, tag) - else: - group_src_rank = get_group_rank(pg, src) - return pg.recv([tensor], group_src_rank, tag) + group_src = _canonicalize_group_rank(group, src, group_src) + return group.recv([tensor], group_src, tag) @_exception_logger def send( - tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 + tensor: torch.Tensor, + dst: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, + group_dst: Optional[int] = None, ) -> None: """ Send a tensor synchronously. @@ -2323,14 +2359,12 @@ def send( group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match send with remote recv + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``. """ - if get_rank() == dst: - raise ValueError( - "Invalid destination rank: destination rank should not be the same as " - "the rank of the current process." - ) - + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): _warn_not_in_group("send") @@ -2339,12 +2373,7 @@ def send( if tensor.is_complex(): tensor = torch.view_as_real(tensor) - if group is None or group is GroupMember.WORLD: - default_pg = _get_default_group() - default_pg.send([tensor], dst, tag).wait() - else: - group_dst_rank = get_group_rank(group, dst) - group.send([tensor], group_dst_rank, tag).wait() + group.send([tensor], group_dst, tag).wait() @_exception_logger @@ -2353,6 +2382,7 @@ def recv( src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0, + group_src: Optional[int] = None, ) -> int: """ Receives a tensor synchronously. @@ -2367,7 +2397,7 @@ def recv( group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match recv with remote send - + group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. Returns: Sender rank -1, if not part of the group @@ -2381,23 +2411,18 @@ def recv( if tensor.is_complex(): tensor = torch.view_as_real(tensor) - pg = group or _get_default_group() + group = _group_or_default_group(group) - if src is None: - work = pg.recv_anysource([tensor], tag) + if src is None and group_src is None: + work = group.recv_anysource([tensor], tag) work.wait() src_rank = work._source_rank() - if group is None or group is GroupMember.WORLD: - return src_rank - else: - return get_global_rank(pg, src_rank) + return get_global_rank(group, src_rank) else: - if group is None or group is GroupMember.WORLD: - pg.recv([tensor], src, tag).wait() - else: - group_src_rank = get_group_rank(pg, src) - pg.recv([tensor], group_src_rank, tag).wait() - return src + group_src = _canonicalize_group_rank(group, src, group_src) + _check_not_self_rank(group, group_src, "source") + group.recv([tensor], group_src, tag).wait() + return get_global_rank(group, group_src) class _IllegalWork(Work):