Skip to content

Commit

Permalink
[PGNCCL] Make sure we do not use split for P2P comm creation (pytorch…
Browse files Browse the repository at this point in the history
…#139013)

Resolve comment pytorch#138527 (comment)

There was a split-vs-P2P bug:
When P2P comm creation invokes `getNCCLComm`, it may see a `split_from` options which is meant for the previous PG creation. Then the P2P comm creation may use `ncclCommSplit` and hang, because not all ranks join this call. The bug slips previously/today because there is no CI test with the following recipe: eager init + new group + P2P in that new group.

Pull Request resolved: pytorch#139013
Approved by: https://github.com/shuqiangzhang
  • Loading branch information
kwen2501 authored and pytorchmergebot committed Dec 9, 2024
1 parent 219e9c8 commit d99c9c2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime
store=store,
rank=self.rank,
world_size=self.world_size,
device_id=device,
# TODO (kwen2501): disabled eager init below as this test is failing
# with bug fix #139013. Temporarily use lazy init to cover the
# composability aspect of this test.
# device_id=device,
)
device_mesh = init_device_mesh(
"cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
Expand Down
22 changes: 22 additions & 0 deletions test/distributed/test_c10d_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,28 @@ def test_non_blocking_p2p(self):
self.assertEqual(send_tensor, recv_tensor)
dist.destroy_process_group()

@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("eager_init", [True, False])
def test_subgroup_p2p(self, eager_init: bool):
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
c10d.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
device_id=device if eager_init else None,
)
send_tensor = torch.ones(10, 10, device=device)
group = dist.new_group()
if self.rank == 0:
dist.send(send_tensor, 1, group=group)
if self.rank == 1:
recv_tensor = torch.rand(10, 10, device=device)
dist.recv(recv_tensor, 0, group=group)
self.assertEqual(send_tensor, recv_tensor)
dist.destroy_process_group()

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_get_uid(self):
Expand Down
7 changes: 6 additions & 1 deletion torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2519,7 +2519,12 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
#endif

#ifdef NCCL_HAS_COMM_SPLIT
if (options_->split_from) {
// Use split to create a new communicator only if:
// 1. The parent comm is known; AND
// 2. The new comm is not for a point-to-point operation.
// ncclCommSplit() is a collective call, so it does not work for P2P
// operations.
if (options_->split_from && !singleP2POp) {
// Find a valid, healthy communicator to split from if possible.
std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
auto& other_comms = options_->split_from->devNCCLCommMap_;
Expand Down

0 comments on commit d99c9c2

Please sign in to comment.