From e4252ea622952b7478ecd9fd552ecc350b06b16f Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Sun, 17 Nov 2024 10:20:48 -0800
Subject: [PATCH] [C10D] Support group_dst in scatter/gather (+object) ops (#140827)

Also add missing mypy typing and a few asserts to make mypy happy

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Note: object collective version canonicalizes to global instead of group rank,
simply because this left more of the original code intact and required fewer
conversions overall.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140827
Approved by: https://github.com/kwen2501
---
 test/distributed/test_c10d_nccl.py               | 133 +++++++++++-----
 .../sharding_spec/chunk_sharding_spec.py         |  10 +-
 torch/distributed/distributed_c10d.py            | 145 +++++++++++-------
 3 files changed, 197 insertions(+), 91 deletions(-)

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 1c9d54621cb75d..504b18944e704a 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
         input = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
-            torch.distributed.gather(
-                input,
-                gather_list=gather_list,
-                dst=self.rank,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                # global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    dst=self.rank,
+                    group=subgroup,
+                    async_op=False,
+                )
             for src in range(len(gather_list)):
                 expected = (torch.ones_like(input) * self.rank) + src
                 self.assertEqual(gather_list[src], expected)
         else:
-            torch.distributed.gather(
-                input,
-                gather_list=None,
-                dst=self.rank - 1,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    dst=self.rank - 1,
+                    group=subgroup,
+                    async_op=False,
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_object_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_object_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
             # another weird thing- what's the point of making me specify some empty objects in my list?
             # empty list should be valid imo. (but it throws an error)
             gather_list = [{}, {}]
-            torch.distributed.gather_object(
-                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, dst=self.rank, group=subgroup
+                )
             for src in range(len(gather_list)):
                 self.assertEqual(gather_list[src]["rank"], self.rank + src)
         else:
-            torch.distributed.gather_object(
-                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, dst=self.rank - 1, group=subgroup
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
         x = torch.empty((10,), device=device)
         expected = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
         else:
             scatter_list = [
                 torch.ones((10,), device=device) * (self.rank - 1),
                 torch.ones((10,), device=device) * self.rank,
             ]
-            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(
+                    x, scatter_list=scatter_list, src=self.rank, group=subgroup
+                )
         self.assertEqual(x, expected)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_object_list_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_object_list_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
         scatter_object_output_list = [None]
         expected = [{"rank": self.rank}]
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=None,
-                src=self.rank + 1,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    src=self.rank + 1,
+                    group=subgroup,
+                )
         else:
             scatter_object_input_list = [
                 {"rank": self.rank - 1},
                 {"rank": self.rank},
             ]
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=scatter_object_input_list,
-                src=self.rank,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    src=self.rank,
+                    group=subgroup,
+                )
         self.assertEqual(scatter_object_output_list, expected)
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
index dd0e354dfc25cc..5101ed3b8da591 100644
--- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
@@ -132,6 +132,7 @@ def shard(
         local_shards = []
         local_tensor = None
         local_metadata = None
+
         tensors_to_scatter = cast(
             List[Optional[torch.Tensor]],
             [None] * dist.get_world_size(process_group),
@@ -192,9 +193,16 @@ def shard(
                     process_group, src_for_scatter
                 )
 
+            tensors_to_scatter_: Optional[List[torch.Tensor]] = None
+            if current_rank == src_rank:
+                tensors_to_scatter_ = []
+                for t in tensors_to_scatter:
+                    assert isinstance(t, torch.Tensor)
+                    tensors_to_scatter_.append(t)
+
             dist.scatter(
                 local_tensor,
-                scatter_list=tensors_to_scatter if current_rank == src_rank else None,
+                scatter_list=tensors_to_scatter_,
                 src=src_for_scatter,
                 group=process_group,
             )
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index fd137150997516..a2a3eeaf2b5385 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1122,18 +1122,23 @@ def _canonicalize_group_rank(
     group: ProcessGroup,
     global_rank: Optional[int] = None,
     group_rank: Optional[int] = None,
+    return_global: bool = False,
 ) -> int:
     """
     Helper method to take _either_ a global rank or a group rank and produce a group rank.
+
+    If 'return_global' is true, produce a global rank instead of a group rank.
     """
+
     if group_rank is not None:
         if global_rank is not None:
             raise ValueError("Can't specify both group_rank and global_rank")
+        global_rank = get_global_rank(group, group_rank)
     else:
         if global_rank is None:
             raise ValueError("Must specify global_rank or group_rank")
         group_rank = get_group_rank(group, global_rank)
-    return group_rank
+    return global_rank if return_global else group_rank
 
 
 def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
@@ -2951,7 +2956,13 @@ def all_gather_object(object_list, obj, group=None):
 
 
 @_exception_logger
-def gather_object(obj, object_gather_list=None, dst=0, group=None):
+def gather_object(
+    obj: Any,
+    object_gather_list: Optional[List[Any]] = None,
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    group_dst: Optional[int] = None,
+):
     """
     Gathers picklable objects from the whole group in a single process.
 
@@ -2964,9 +2975,11 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
             should be correctly sized as the size of the group for this
             collective and will contain the output. Must be ``None`` on non-dst
             ranks. (default is ``None``)
-        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
+        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument).
+            (If both ``dst`` and ``group_dst`` are None, default is global rank 0)
         group: (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used. Default is ``None``.
+        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``
 
     Returns:
         None. On the ``dst`` rank, ``object_gather_list`` will contain the
@@ -3010,13 +3023,17 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
+    group = _group_or_default_group(group)
+    if dst is None and group_dst is None:
+        dst = 0
+    global_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("gather_object")
         return
 
     # Ensure object_gather_list is specified appropriately.
-    my_rank = get_rank()
-    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
+    my_global_rank = get_rank()
+    _validate_output_list_for_rank(my_global_rank, global_dst, object_gather_list)
 
     current_device = _get_object_coll_device(group)
     input_tensor, local_size = _object_to_tensor(obj, current_device, group)
@@ -3037,7 +3054,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
     # Resize tensor to max size across all ranks.
     input_tensor.resize_(max_object_size)
     # Avoid populating output tensors if the result won't be gathered on this rank.
-    if my_rank == dst:
+    if my_global_rank == global_dst:
         coalesced_output_tensor = torch.empty(
             max_object_size * group_size, dtype=torch.uint8, device=current_device
         )
@@ -3049,12 +3066,14 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
     # All ranks call gather with equal-sized tensors.
     gather(
         input_tensor,
-        gather_list=output_tensors if my_rank == dst else None,  # type: ignore[possibly-undefined]
-        dst=dst,
+        gather_list=output_tensors if my_global_rank == global_dst else None,  # type: ignore[possibly-undefined]
+        dst=global_dst,
         group=group,
     )
-    if my_rank != dst:
+    if my_global_rank != global_dst:
         return
+
+    assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
     for i, tensor in enumerate(output_tensors):
         tensor = tensor.type(torch.uint8)
         tensor_size = object_size_list[i]
@@ -3366,7 +3385,11 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
 
 @_exception_logger
 def scatter_object_list(
-    scatter_object_output_list, scatter_object_input_list, src=0, group=None
+    scatter_object_output_list: List[Any],
+    scatter_object_input_list: Optional[List[Any]] = None,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    group_src: Optional[int] = None,
 ):
     """
     Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
@@ -3379,13 +3402,15 @@ def scatter_object_list(
     Args:
         scatter_object_output_list (List[Any]): Non-empty list whose first
            element will store the object scattered to this rank.
-        scatter_object_input_list (List[Any]): List of input objects to scatter.
+        scatter_object_input_list (List[Any], optional): List of input objects to scatter.
             Each object must be picklable. Only objects on the ``src`` rank will
             be scattered, and the argument can be ``None`` for non-src ranks.
         src (int): Source rank from which to scatter ``scatter_object_input_list``.
             Source rank is based on global process group (regardless of ``group`` argument).
+            (If both ``src`` and ``group_src`` are None, default is global rank 0)
         group: (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used. Default is ``None``.
+        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``
 
     Returns:
         ``None``. If rank is part of the group, ``scatter_object_output_list``
@@ -3422,6 +3447,10 @@ def scatter_object_list(
         >>> output_list
         [{1: 2}]
     """
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("scatter_object_list")
         return
@@ -3434,9 +3463,13 @@ def scatter_object_list(
             "Expected argument scatter_object_output_list to be a list of size at least 1."
         )
 
-    my_rank = get_rank()
+    my_global_rank = get_rank()
     pg_device = _get_object_coll_device(group)
-    if my_rank == src:
+    if my_global_rank == global_src:
+        if scatter_object_input_list is None:
+            raise ValueError(
+                "source rank must provide non-None scatter_object_input_list"
+            )
         tensor_list, tensor_sizes = zip(
             *[
                 _object_to_tensor(obj, pg_device, group)
@@ -3445,15 +3478,14 @@ def scatter_object_list(
         )
         tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
 
-    # Src rank broadcasts the maximum tensor size. This is because all ranks are
-    # expected to call into scatter() with equal-sized tensors.
-    if my_rank == src:
+        # Src rank broadcasts the maximum tensor size. This is because all ranks are
+        # expected to call into scatter() with equal-sized tensors.
         max_tensor_size = max(tensor_sizes)  # type: ignore[possibly-undefined]
         for tensor in tensor_list:  # type: ignore[possibly-undefined]
             tensor.resize_(max_tensor_size)
     else:
         max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
-    broadcast(max_tensor_size, src=src, group=group)
+    broadcast(max_tensor_size, src=global_src, group=group)
 
     # Scatter actual serialized objects
     output_tensor = torch.empty(
@@ -3461,8 +3493,8 @@ def scatter_object_list(
     )
     scatter(
         output_tensor,
-        scatter_list=None if my_rank != src else tensor_list,  # type: ignore[possibly-undefined]
-        src=src,
+        scatter_list=None if my_global_rank != global_src else tensor_list,  # type: ignore[possibly-undefined]
+        src=global_src,
         group=group,
     )
 
@@ -3470,8 +3502,8 @@ def scatter_object_list(
     obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
     scatter(
         obj_tensor_size,
-        scatter_list=None if my_rank != src else tensor_sizes,  # type: ignore[possibly-undefined]
-        src=src,
+        scatter_list=None if my_global_rank != global_src else tensor_sizes,  # type: ignore[possibly-undefined]
+        src=global_src,
         group=group,
     )
 
@@ -3779,7 +3811,14 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
 
 
 @_exception_logger
-def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
+def gather(
+    tensor: torch.Tensor,
+    gather_list: Optional[List[torch.Tensor]] = None,
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_dst: Optional[int] = None,
+):
     """
     Gathers a list of tensors in a single process.
 
@@ -3790,10 +3829,12 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
         gather_list (list[Tensor], optional): List of appropriately,
             same-sized tensors to use for gathered data
            (default is None, must be specified on the destination rank)
-        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
+        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument).
+            (If both ``dst`` and ``group_dst`` are None, default is global rank 0)
         group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -3826,26 +3867,22 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
     else:
         gather_list = []
     _ensure_all_tensors_same_dtype(tensor, gather_list)
-
+    group = _group_or_default_group(group)
     if _rank_not_in_group(group):
         _warn_not_in_group("gather")
         return
-
-    my_rank = get_rank()
-    _validate_output_list_for_rank(my_rank, dst, gather_list)
-    output_tensors = [gather_list] if dst == my_rank else []
+    if dst is None and group_dst is None:
+        dst = 0
+    global_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=True)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
+    my_global_rank = get_rank()
+    _validate_output_list_for_rank(my_global_rank, global_dst, gather_list)
+    output_tensors = [gather_list] if global_dst == my_global_rank else []
     input_tensors = [tensor]
 
     opts = GatherOptions()
-    opts.rootRank = dst
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.gather(output_tensors, input_tensors, opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.gather(output_tensors, input_tensors, opts)
+    opts.rootRank = group_dst
+    work = group.gather(output_tensors, input_tensors, opts)
 
     if async_op:
         return work
@@ -3854,7 +3891,14 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
 
 
 @_exception_logger
-def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
+def scatter(
+    tensor: torch.Tensor,
+    scatter_list: Optional[List[torch.Tensor]] = None,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_src: Optional[int] = None,
+):
     """
     Scatters a list of tensors to all processes in a group.
 
@@ -3868,10 +3912,11 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
         scatter_list (list[Tensor]): List of tensors to scatter (default is
            None, must be specified on the source rank)
         src (int): Source rank on global process group (regardless of ``group`` argument).
-            Default is 0
+            (If both ``src`` and ``group_src`` are None, default is global rank 0)
         group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -3902,14 +3947,17 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
 
     """
     _check_single_tensor(tensor, "tensor")
-
     # Parameter ``scatter_list`` may be left unspecified on non-src ranks.
     if scatter_list:
         _check_tensor_list(scatter_list, "scatter_list")
     else:
         scatter_list = []
     _ensure_all_tensors_same_dtype(tensor, scatter_list)
-
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
+    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
     if _rank_not_in_group(group):
         _warn_not_in_group("scatter")
         return
@@ -3918,8 +3966,8 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
     ]
     tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
 
-    my_rank = get_rank()
-    if src == my_rank:
+    my_global_rank = get_rank()
+    if global_src == my_global_rank:
         if not scatter_list:
             raise ValueError(
                 "Argument ``scatter_list`` must be specified on source rank."
@@ -3936,16 +3984,9 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
     output_tensors = [tensor]
 
     opts = ScatterOptions()
-    opts.rootRank = src
+    opts.rootRank = group_src
    opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.scatter(output_tensors, input_tensors, opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.scatter(output_tensors, input_tensors, opts)
+    work = group.scatter(output_tensors, input_tensors, opts)
 
     if async_op:
         return work
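
For reference, a minimal usage sketch (not part of the patch) of the group_dst /
group_src keywords introduced above. It assumes a build with this patch applied,
a 4-rank job launched with torchrun, and the gloo backend; the script name and
the [0, 2] / [1, 3] subgroup layout are illustrative choices that mirror the
tests in this patch.

    # usage_sketch.py  (illustrative only; assumes this patch is applied)
    # Launch with: torchrun --nproc_per_node=4 usage_sketch.py
    import torch
    import torch.distributed as dist


    def main() -> None:
        dist.init_process_group("gloo")
        rank = dist.get_rank()

        # new_subgroups_by_enumeration returns the 2-rank subgroup containing this rank.
        subgroup, _ = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]])
        my_group_rank = dist.get_group_rank(subgroup, rank)

        # gather with group_dst: each subgroup gathers onto its own group rank 0,
        # with no manual translation from group rank to global rank.
        x = torch.ones(10) * rank
        if my_group_rank == 0:
            gather_list = [torch.empty_like(x) for _ in range(subgroup.size())]
            dist.gather(x, gather_list=gather_list, group_dst=0, group=subgroup)
        else:
            dist.gather(x, gather_list=None, group_dst=0, group=subgroup)

        # scatter with group_src works the same way: group rank 1 is the source
        # within each subgroup and is the only rank that supplies scatter_list.
        y = torch.empty(10)
        if my_group_rank == 1:
            scatter_list = [torch.full((10,), float(r)) for r in range(subgroup.size())]
            dist.scatter(y, scatter_list=scatter_list, group_src=1, group=subgroup)
        else:
            dist.scatter(y, scatter_list=None, group_src=1, group=subgroup)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

The object collectives touched by this patch (gather_object, scatter_object_list)
accept the same group_dst / group_src keywords; passing both the global and the
group form of the rank is rejected by _canonicalize_group_rank with a ValueError.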