Skip to content

Commit

Permalink
[C10D] Support group_dst/group_src in c10d send/recv object_list (pyt…
Browse files Browse the repository at this point in the history
…orch#140847)

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Pull Request resolved: pytorch#140847
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#140843
  • Loading branch information
wconstab authored and pobin6 committed Dec 5, 2024
1 parent 8f7dd05 commit 52985f3
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 15 deletions.
19 changes: 16 additions & 3 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3928,7 +3928,10 @@ def test_broadcast_subgroup(self, group_rank):
"set_device",
[SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
)
def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
@parametrize("group_rank", [True, False])
def test_send_recv_object_list_subgroup(
self, set_device: SetDeviceMethod, group_rank
):
world_size = 4
if self.rank >= world_size:
return
Expand All @@ -3940,12 +3943,22 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
device = torch.device("cuda:%d" % self.rank)
if self.rank == 0 or self.rank == 2:
x = [{}]
c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device)
if group_rank:
c10d.recv_object_list(x, group_src=1, group=subgroup, device=device)
else:
c10d.recv_object_list(
x, src=self.rank + 1, group=subgroup, device=device
)
expected = [{"rank": self.rank + 1}]
self.assertEqual(x, expected)
else:
x = [{"rank": self.rank}]
c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device)
if group_rank:
c10d.send_object_list(x, group_dst=0, group=subgroup, device=device)
else:
c10d.send_object_list(
x, dst=self.rank - 1, group=subgroup, device=device
)

@requires_nccl()
@skip_if_lt_x_gpu(4)
Expand Down
36 changes: 24 additions & 12 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3087,7 +3087,13 @@ def gather_object(


@_exception_logger
def send_object_list(object_list, dst, group=None, device=None):
def send_object_list(
object_list: List[Any],
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_dst: Optional[int] = None,
):
"""
Sends picklable objects in ``object_list`` synchronously.
Expand All @@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
device (``torch.device``, optional): If not None, the objects are
serialized and converted to tensors which are moved to the
``device`` before sending. Default is ``None``.
group_dst (int, optional): Destination rank on ``group``.
Must specify one of ``dst`` and ``group_dst`` but not both
Returns:
``None``.
Expand Down Expand Up @@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
>>> objects
['foo', 12, {1: 2}]
"""
if get_rank() == dst:
raise ValueError(
"Invalid destination rank: destination rank should not be the same as "
"the rank of the current process."
)
group = _group_or_default_group(group)
group_dst = _canonicalize_group_rank(group, dst, group_dst)
_check_not_self_rank(group, group_dst, "destination")

if _rank_not_in_group(group):
_warn_not_in_group("send_object_list")
Expand All @@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
object_sizes_tensor = torch.cat(size_list)

# Send object sizes
send(object_sizes_tensor, dst=dst, group=group)
send(object_sizes_tensor, group_dst=group_dst, group=group)

# Concatenate and send serialized object tensors
# Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
Expand All @@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
else:
object_tensor = torch.cat(tensor_list)

send(object_tensor, dst=dst, group=group)
send(object_tensor, group_dst=group_dst, group=group)


@_exception_logger
def recv_object_list(object_list, src=None, group=None, device=None):
def recv_object_list(
object_list: List[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
group_src: Optional[int] = None,
):
"""
Receives picklable objects in ``object_list`` synchronously.
Expand All @@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
the default process group will be used. Default is ``None``.
device (``torch.device``, optional): If not None, receives on this device.
Default is ``None``.
group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
Returns:
Sender rank. -1 if rank is not part of the group. If rank is part of the group,
Expand Down Expand Up @@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
)

# Receive object sizes
rank_sizes = recv(object_sizes_tensor, src=src, group=group)
rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)

# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
Expand All @@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
device=current_device,
)

rank_objects = recv(object_tensor, src=src, group=group)
rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
assert (
rank_sizes == rank_objects
), "Mismatch in return ranks for object sizes and objects."
Expand Down

0 comments on commit 52985f3

Please sign in to comment.