
[Distributed] P2P Operations on NCCL do not respect tag #125079

Closed

andoorve opened this issue Apr 26, 2024 · 5 comments
Labels: module: nccl, oncall: distributed

Comments

@andoorve (Contributor) commented Apr 26, 2024

🐛 Describe the bug

When using NCCL with send/recv operations, we expect the tag argument to be respected for send/recv matching. In practice, it is not.

Example program:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    """ Distributed function to be implemented later. """
    device = torch.device(f"cuda:{rank}")
    if rank == 0:
        tens = torch.ones([33, 4096], dtype=torch.bfloat16, device=device)
        tens2 = 2 * torch.ones([33, 4096], dtype=torch.bfloat16, device=device)
        # Rank 0 sends tag 0 first, then tag 1.
        dist.send(tens2, dst=1, tag=0)
        dist.send(tens, dst=1, tag=1)
    else:
        tens = torch.empty([33, 4096], dtype=torch.bfloat16, device=device)
        tens2 = torch.empty([33, 4096], dtype=torch.bfloat16, device=device)
        # Rank 1 receives in the opposite tag order: tag 1 first, then tag 0.
        dist.recv(tens, src=0, tag=1)
        dist.recv(tens2, src=0, tag=0)
        print(f'{tens}, {tens2}')

def init_process(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

If tags were respected, we would expect tens to be received as a tensor of 1s and tens2 as a tensor of 2s, or at least a hang on the mismatched tags. Instead, the values arrive swapped: the messages are matched purely by call order.

This should be the exact same issue as #94819, but for send/recv instead of isend and irecv.
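For reference, a sketch of the nonblocking variant covered by #94819, mirroring the example above (with NCCL the matching is again by call order, not by tag):

if rank == 0:
    req0 = dist.isend(tens2, dst=1, tag=0)
    req1 = dist.isend(tens, dst=1, tag=1)
    req0.wait()
    req1.wait()
else:
    req1 = dist.irecv(tens, src=0, tag=1)  # posted first, so NCCL matches it to the first send
    req0 = dist.irecv(tens2, src=0, tag=0)
    req1.wait()
    req0.wait()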

Versions

Collecting environment information...
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-1016-gcp-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA L4
GPU 1: NVIDIA L4
GPU 2: NVIDIA L4
GPU 3: NVIDIA L4

Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             96
On-line CPU(s) list:                0-95
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
Stepping:                           7
BogoMIPS:                           4400.42
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.5 MiB (48 instances)
L1i cache:                          1.5 MiB (48 instances)
L2 cache:                           48 MiB (48 instances)
L3 cache:                           77 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] torch==2.2.1
[pip3] triton==2.2.0
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@andoorve (Contributor Author)

cc: @H-Huang @albanD

@andoorve changed the title P2P Operations on NCCL do not respect tag → [Distributed] P2P Operations on NCCL do not respect tag Apr 26, 2024
@cpuhrsch added the module: nccl and oncall: distributed labels and removed the triaged label Apr 30, 2024
@wconstab (Contributor)

This is actually a known limitation. We should document it better, though.

NCCL's API does not support tags, so there is no clear way for us to honor them, even though our APIs expose a tag parameter (so that backends that do support tags can use it).

https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#c.ncclSend

cc @kwen2501 @H-Huang Please keep me honest here.
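One possible workaround, sketched here as an assumption rather than an official recommendation: since NCCL matches sends and receives purely by call order within a communicator, a dedicated process group per logical "tag" can stand in for tags. The pg_tag0/pg_tag1 names below are hypothetical:

# All ranks must call new_group() in the same order.
pg_tag0 = dist.new_group(ranks=[0, 1])  # stands in for tag=0
pg_tag1 = dist.new_group(ranks=[0, 1])  # stands in for tag=1

if rank == 0:
    dist.send(tens2, dst=1, group=pg_tag0)
    dist.send(tens, dst=1, group=pg_tag1)
else:
    # Each recv can only match a send on its own group/communicator.
    dist.recv(tens, src=0, group=pg_tag1)
    dist.recv(tens2, src=0, group=pg_tag0)

Since each group maps to its own NCCL communicator, ordering across groups is decoupled, though blocking calls could still deadlock if the two communicators' kernels cannot make progress concurrently.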

wconstab added a commit that referenced this issue Apr 30, 2024 (#125278, message below)
pytorchmergebot pushed a commit that referenced this issue May 1, 2024
Existing documentation on isend/irecv also applies to send/recv. This PR
copies the doc/warning to send/recv ops as well.

Note: tag may be supplied, but will be ignored when used with nccl
backend.

Fixes #94819 #125079

Pull Request resolved: #125278
Approved by: https://github.com/kwen2501
@wconstab (Contributor) commented May 1, 2024

Closing as fixed by updating docs.

@wconstab closed this as completed May 1, 2024
@kwen2501 (Contributor) commented May 1, 2024

Wanted to note that tagging is not intended to support out-of-order send/recv calls (in particular the blocking versions). Neither NCCL nor MPI would be able to support the example in this issue.
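To illustrate the point (a sketch reusing the variables from the repro above): with blocking calls, the receives must be posted in the same order as the matching sends, regardless of tag.

# Order-matched blocking receives for the repro's two sends:
dist.recv(tens2, src=0, tag=0)  # matches the first send (tag=0)
dist.recv(tens, src=0, tag=1)   # matches the second send (tag=1)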

@andoorve (Contributor Author) commented May 1, 2024

In that case wouldn't we at least expect hangs (with MPI)? Wouldn't both processes block on differently tagged send/recv?

petrex pushed a commit to petrex/pytorch that referenced this issue May 3, 2024 (same change as pytorch#125278).