[Feature]: CI: Test on NVLink-enabled machine #4770
Comments
Ideally, distributed-related tests should run on both an NVLink machine and a PCIe machine.
https://buildkite.com/vllm/ci/builds/7222#018f7121-bd93-456f-ae01-25e0a4e63061
The log above shows that it didn't pass the p2p check (but it passed the nvlink check). So there might be some non-hardware problem on the Buildkite agent machine?
It skips the nvlink check because it reads the p2p cache file directly rather than testing it.
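To make the "reads the cache file instead of testing" point concrete, here is a minimal, hypothetical sketch of such a cached check. The cache path and function names are made up for illustration; this is not vLLM's actual implementation.

```python
import json
import os

import torch

CACHE_FILE = os.path.expanduser("~/.cache/p2p_access_cache.json")  # hypothetical path

def can_actually_p2p(i: int, j: int) -> bool:
    """Round-trip a tensor between the two GPUs and verify the data survives."""
    a = torch.randn(5, device=f"cuda:{i}") + 123.0
    return torch.equal(a.to(f"cuda:{j}").to(f"cuda:{i}"), a)

def cached_p2p_check(i: int, j: int) -> bool:
    key = f"{i}->{j}"
    cache = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE) as f:
            cache = json.load(f)
    if key in cache:
        # Cache hit: the live test (and any NVLink probing) is skipped entirely,
        # so a stale file can mask a change in the machine's real capability.
        return cache[key]
    cache[key] = can_actually_p2p(i, j)
    with open(CACHE_FILE, "w") as f:
        json.dump(cache, f)
    return cache[key]
```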
@youkaichao Custom allreduce also works with 2 PCIe cards as a special case.
@hanzhi713 do you mean custom allreduce with 2 PCIe cards?
It's more performant than NCCL in either of two cases. Currently, only case 1 is enabled.
Case 2 is not enabled/currently supported because the memory model of multiple GPUs over the PCIe fabric is not very well documented; I'm afraid we'll run into memory ordering/visibility issues. #2760 (comment) has a comment regarding the performance with more than two PCIe GPUs.
For non-NVLink GPUs, do they need p2p access for custom allreduce to work? On our CI machine, with 2/4 * L4, it seems there is no p2p access. The machine topology is:
Empirically, I find p2p access is only available for PIX and NV# connections.
Yes, you need P2P access for custom allreduce to work. Not all PCIe platforms support this feature. I have a bunch of A30, A10, and T4 machines whose topology is all SYS, but they do support PCIe P2P.
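For reference, a quick way to dump what the CUDA runtime itself reports for every GPU pair (this only surfaces `cudaDeviceCanAccessPeer`, not whether a copy would actually be correct):

```python
import torch

n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            # True means the runtime claims device i can directly access device j's memory
            print(i, j, torch.cuda.can_device_access_peer(i, j))
```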
I'm quite confused about how to detect p2p access ability. On my L4 * 2 machine:

```python
import torch

print(torch.cuda.can_device_access_peer(0, 1))  # False

def _can_actually_p2p(idx_a, idx_b):
    dev_i = f"cuda:{idx_a}"
    dev_j = f"cuda:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()

print(_can_actually_p2p(0, 1))  # True
```
I believe cudaMemcpyPeer implements something like this, so even if there's no p2p support, it might still work. We use it to check that, when p2p is reported as supported, it actually works and produces the correct result.
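A small illustration of why a copy-based check alone cannot tell a true peer copy from a host-staged fallback (assumes a 2-GPU machine):

```python
import torch

a = torch.randn(1024, device="cuda:0")
b_copy = a.to("cuda:1")           # driver may use p2p or fall back internally
b_staged = a.cpu().to("cuda:1")   # explicit round trip through host memory
print(torch.equal(b_copy, b_staged))  # True either way, so the data check can't distinguish the paths
```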
Here is another script:

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(dist.get_rank())
data = torch.zeros(2, 2, device='cuda') + dist.get_rank() + 1

def share_cuda_tensor(data, src, rank):
    if rank == src:
        # export the CUDA tensor via torch's IPC-based reduction and broadcast the handle
        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
        dist.broadcast_object_list([[func, args]], src)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, src)
        func, args = recv[0]
        # rebuild the tensor from the IPC handle, mapping the source rank's memory
        data = func(*args)
    return data

data = share_cuda_tensor(data, 0, dist.get_rank())
if dist.get_rank() == 1:
    data += 1
dist.barrier()
print(f"Rank {dist.get_rank()} has data {data}")
```

The relevant check is at https://github.com/pytorch/pytorch/blob/de42af4b0087118cf5527261c532927efcb9a0df/torch/csrc/StorageSharing.cpp#L324; see there for details. My question is: why does this work, given the p2p check above fails?
And, if it works, does that mean p2p access is actually available?
It's not always true. See also pytorch/pytorch#119638 for a discussion on this.
Not quite sure why. The doc says: "Maps memory exported from another process with cudaIpcGetMemHandle into the current device address space. For contexts on different devices cudaIpcOpenMemHandle can attempt to enable peer access between the devices as if the user called cudaDeviceEnablePeerAccess. This behavior is controlled by the cudaIpcMemLazyEnablePeerAccess flag. cudaDeviceCanAccessPeer can determine if a mapping is possible." And that is what I assume: IPC is only possible if cudaDeviceEnablePeerAccess returns True.
I would say IPC + p2p is only possible if cudaDeviceEnablePeerAccess returns True. There is another case: processes can use IPC on the same GPU, which is how PyTorch uses it.
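A minimal sketch of that same-GPU IPC case (device index and tensor shape are arbitrary; the child process receives the tensor through the same CUDA IPC machinery that `reduce_tensor` relies on, and no peer access is involved):

```python
import torch
import torch.multiprocessing as mp

def child(shared):
    # the child maps the parent's CUDA allocation via CUDA IPC, so this
    # in-place write lands in the parent's memory
    shared += 1

if __name__ == "__main__":
    mp.set_start_method("spawn")  # CUDA tensors require the spawn start method
    t = torch.zeros(4, device="cuda:0")
    p = mp.Process(target=child, args=(t,))
    p.start()
    p.join()
    print(t)  # the parent sees the child's write: tensor([1., 1., 1., 1.], device='cuda:0')
```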
Closed as finished in #5689.
🚀 The feature, motivation and pitch
The custom allreduce code https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/custom_all_reduce.py works with NVLink only.
To run the corresponding distributed tests (https://github.com/vllm-project/vllm/blob/main/tests/distributed/test_custom_all_reduce.py), we need at least 2 NVLink-enabled GPUs (ideally 4, to cover all cases).
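For context, a hedged sketch of how such tests could be gated in CI until the hardware is available. The helper name and the NVML probing below are illustrative assumptions, not vLLM's actual test code:

```python
import pytest
import torch

def has_nvlink() -> bool:
    """Best-effort NVLink probe via NVML; treat any failure as 'no NVLink'."""
    try:
        import pynvml
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        for link in range(6):  # probe the first few NVLink links
            try:
                if pynvml.nvmlDeviceGetNvLinkState(handle, link) == pynvml.NVML_FEATURE_ENABLED:
                    return True
            except pynvml.NVMLError:
                break
        return False
    except Exception:
        return False

# applied to the custom allreduce tests so they only run on suitable machines
requires_nvlink_pair = pytest.mark.skipif(
    torch.cuda.device_count() < 2 or not has_nvlink(),
    reason="needs at least 2 NVLink-connected GPUs",
)
```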
cc @simon-mo
Alternatives
No response
Additional context
No response