Skip to content

Commit

Permalink
[C10D] Support group_dst in scatter/gather (+object) ops (pytorch#140827
Browse files Browse the repository at this point in the history
)

Also add missing mypy typing and a few asserts to make mypy happy

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in pytorch#140460

Note: object collective version canonicalizes to global instead of group
rank, simply becuase this left more of the original code intact and
required less conversions overall.

Pull Request resolved: pytorch#140827
Approved by: https://github.com/kwen2501
  • Loading branch information
wconstab authored and youssef62 committed Nov 23, 2024
1 parent 5d94fc9 commit 570d032
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 91 deletions.
133 changes: 95 additions & 38 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_gather_subgroup(self):
@parametrize("group_rank", [True, False])
def test_gather_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
# just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
Expand All @@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
input = torch.ones((10,), device=device) * self.rank
if self.rank == 0 or self.rank == 2:
gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
torch.distributed.gather(
input,
gather_list=gather_list,
dst=self.rank,
group=subgroup,
async_op=False,
)
if group_rank:
# global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
torch.distributed.gather(
input,
gather_list=gather_list,
group_dst=0,
group=subgroup,
async_op=False,
)
else:
torch.distributed.gather(
input,
gather_list=gather_list,
dst=self.rank,
group=subgroup,
async_op=False,
)
for src in range(len(gather_list)):
expected = (torch.ones_like(input) * self.rank) + src
self.assertEqual(gather_list[src], expected)
else:
torch.distributed.gather(
input,
gather_list=None,
dst=self.rank - 1,
group=subgroup,
async_op=False,
)
if group_rank:
torch.distributed.gather(
input,
gather_list=None,
group_dst=0,
group=subgroup,
async_op=False,
)
else:
torch.distributed.gather(
input,
gather_list=None,
dst=self.rank - 1,
group=subgroup,
async_op=False,
)

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_gather_object_subgroup(self):
@parametrize("group_rank", [True, False])
def test_gather_object_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
# just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
Expand All @@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
# another weird thing- what's the point of making me specify some empty objects in my list?
# empty list should be valid imo. (but it throws an error)
gather_list = [{}, {}]
torch.distributed.gather_object(
input, object_gather_list=gather_list, dst=self.rank, group=subgroup
)
if group_rank:
torch.distributed.gather_object(
input, object_gather_list=gather_list, group_dst=0, group=subgroup
)
else:
torch.distributed.gather_object(
input, object_gather_list=gather_list, dst=self.rank, group=subgroup
)
for src in range(len(gather_list)):
self.assertEqual(gather_list[src]["rank"], self.rank + src)
else:
torch.distributed.gather_object(
input, object_gather_list=None, dst=self.rank - 1, group=subgroup
)
if group_rank:
torch.distributed.gather_object(
input, object_gather_list=None, group_dst=0, group=subgroup
)
else:
torch.distributed.gather_object(
input, object_gather_list=None, dst=self.rank - 1, group=subgroup
)

@requires_nccl()
@skip_if_lt_x_gpu(4)
Expand Down Expand Up @@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_scatter_subgroup(self):
@parametrize("group_rank", [True, False])
def test_scatter_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
return
Expand All @@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
x = torch.empty((10,), device=device)
expected = torch.ones((10,), device=device) * self.rank
if self.rank == 0 or self.rank == 2:
c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
if group_rank:
c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
else:
c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
else:
scatter_list = [
torch.ones((10,), device=device) * (self.rank - 1),
torch.ones((10,), device=device) * self.rank,
]
c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
if group_rank:
c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
else:
c10d.scatter(
x, scatter_list=scatter_list, src=self.rank, group=subgroup
)
self.assertEqual(x, expected)

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_scatter_object_list_subgroup(self):
@parametrize("group_rank", [True, False])
def test_scatter_object_list_subgroup(self, group_rank):
world_size = 4
if self.rank >= world_size:
return
Expand All @@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
scatter_object_output_list = [None]
expected = [{"rank": self.rank}]
if self.rank == 0 or self.rank == 2:
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=None,
src=self.rank + 1,
group=subgroup,
)
if group_rank:
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=None,
group_src=1,
group=subgroup,
)
else:
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=None,
src=self.rank + 1,
group=subgroup,
)

else:
scatter_object_input_list = [
{"rank": self.rank - 1},
{"rank": self.rank},
]
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=scatter_object_input_list,
src=self.rank,
group=subgroup,
)
if group_rank:
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=scatter_object_input_list,
group_src=1,
group=subgroup,
)
else:
c10d.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=scatter_object_input_list,
src=self.rank,
group=subgroup,
)
self.assertEqual(scatter_object_output_list, expected)


Expand Down
10 changes: 9 additions & 1 deletion torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def shard(
local_shards = []
local_tensor = None
local_metadata = None

tensors_to_scatter = cast(
List[Optional[torch.Tensor]],
[None] * dist.get_world_size(process_group),
Expand Down Expand Up @@ -192,9 +193,16 @@ def shard(
process_group, src_for_scatter
)

tensors_to_scatter_: Optional[List[torch.Tensor]] = None
if current_rank == src_rank:
tensors_to_scatter_ = []
for t in tensors_to_scatter:
assert isinstance(t, torch.Tensor)
tensors_to_scatter_.append(t)

dist.scatter(
local_tensor,
scatter_list=tensors_to_scatter if current_rank == src_rank else None,
scatter_list=tensors_to_scatter_,
src=src_for_scatter,
group=process_group,
)
Expand Down
Loading

0 comments on commit 570d032

Please sign in to comment.