
[Feature]: CI: Test on NVLink-enabled machine #4770

Closed
youkaichao opened this issue May 12, 2024 · 18 comments

youkaichao (Member) commented May 12, 2024

🚀 The feature, motivation and pitch

The custom allreduce code at https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/custom_all_reduce.py only works over NVLink.

To run the corresponding distributed tests at https://github.com/vllm-project/vllm/blob/main/tests/distributed/test_custom_all_reduce.py, we need at least 2 NVLink-enabled GPUs (ideally 4, to cover all cases).
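
For reference, here is a minimal sketch of how CI could gate such tests on NVLink connectivity, assuming a recent pynvml that exposes nvmlDeviceGetP2PStatus; the actual check in custom_all_reduce.py may differ:

import pynvml

def gpus_fully_nvlinked(device_ids):
    # Return True only if every pair of the given GPUs reports an NVLink P2P path.
    pynvml.nvmlInit()
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
        for i, h1 in enumerate(handles):
            for j, h2 in enumerate(handles):
                if i == j:
                    continue
                status = pynvml.nvmlDeviceGetP2PStatus(
                    h1, h2, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True
    finally:
        pynvml.nvmlShutdown()

# e.g. in a test: pytest.mark.skipif(not gpus_fully_nvlinked([0, 1]), reason="needs NVLink")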

cc @simon-mo

Alternatives

No response

Additional context

No response

youkaichao (Member, Author)

Ideally, distributed tests should run on both an NVLink machine and a PCIe machine.

AllenDou (Contributor) commented May 13, 2024

https://buildkite.com/vllm/ci/builds/7222#018f7121-bd93-456f-ae01-25e0a4e63061

(eager_allreduce pid=3040) INFO 05-13 08:48:35 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(eager_allreduce pid=3040) WARNING 05-13 08:48:35 custom_all_reduce.py:166] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
FAILED
test_custom_all_reduce.py::test_custom_allreduce[test_target0-2-2] SKIPPED2024-05-13 08:48:35,962	ERROR worker.py:406 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::eager_allreduce() (pid=3040, ip=10.68.0.191)

The log above shows that it didn't pass the P2P check (but it passed the NVLink check). So there might be some non-hardware problem on the Buildkite agent machine?

youkaichao (Member, Author)

The log above shows that it didn't pass the p2p check (but it passed the nvlink check)

It skips the NVLink check because it reads the P2P cache file directly rather than testing it.
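
Roughly, the cached check behaves like the sketch below; the cache path comes from the log above, and the helper names are illustrative rather than vLLM's actual functions:

import json
import os

_CACHE_PATH = os.path.expanduser(
    "~/.config/vllm/gpu_p2p_access_cache_for_0,1.json")  # path seen in the log

def cached_p2p_check(i, j, test_fn):
    # Illustrative sketch: reuse a previously stored P2P result and only run
    # the (expensive) real test when no cached entry exists.
    key = f"{i}->{j}"
    cache = {}
    if os.path.exists(_CACHE_PATH):
        with open(_CACHE_PATH) as f:
            cache = json.load(f)
    if key not in cache:
        cache[key] = bool(test_fn(i, j))
        os.makedirs(os.path.dirname(_CACHE_PATH), exist_ok=True)
        with open(_CACHE_PATH, "w") as f:
            json.dump(cache, f)
    return cache[key]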

hanzhi713 (Contributor) commented May 15, 2024

@youkaichao Custom allreduce also works with 2 PCIe cards as a special case

youkaichao (Member, Author)

@hanzhi713 do you mean custom allreduce with full_nvlink=False? Is it still more performant than nccl?

hanzhi713 (Contributor) commented May 15, 2024

@hanzhi713 do you mean custom allreduce with full_nvlink=False? Is it still more performant than nccl?

It's more performant than NCCL when either

  1. there are only two PCIe GPUs (they can be connected to the PCIe root complex directly or with a PCIe switch), or
  2. there are multiple PCIe GPUs connected to the same PCIe switch.

Currently, only case 1 is enabled.

if world_size > 2 and not full_nvlink:
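
In other words, the quoted line gates the feature roughly as in this paraphrased, self-contained sketch (the function name is illustrative, not the actual vLLM code):

def should_use_custom_allreduce(world_size: int, full_nvlink: bool) -> bool:
    # Case 1 (exactly two GPUs, PCIe or NVLink) is supported.
    # Case 2 (several PCIe GPUs behind one switch) is not enabled, so any
    # world size above 2 additionally requires a full NVLink mesh.
    if world_size > 2 and not full_nvlink:
        return False
    return True

print(should_use_custom_allreduce(world_size=2, full_nvlink=False))  # True
print(should_use_custom_allreduce(world_size=4, full_nvlink=False))  # False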

hanzhi713 (Contributor) commented May 15, 2024

Case 2 is not currently enabled/supported because the memory model of multiple GPUs over a PCIe fabric is not very well documented. I'm afraid we'd run into memory ordering/visibility issues.

#2760 (comment) is a comment regarding performance with more than two PCIe GPUs.

youkaichao (Member, Author)

For non-NVLink GPUs, do they need P2P access for custom allreduce to work?

Our CI machine (2x or 4x L4) does not seem to have P2P access. The machine topology is:

      GPU0  GPU1  GPU2  GPU3  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0   X    PHB   PHB   PHB   0-47          N/A            N/A
GPU1  PHB    X    PHB   PHB   0-47          N/A            N/A
GPU2  PHB   PHB    X    PHB   0-47          N/A            N/A
GPU3  PHB   PHB   PHB    X    0-47          N/A            N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

youkaichao (Member, Author)

Empirically, I find that P2P access is only available for PIX and NV# connections.
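
As a quick check, here is a minimal torch sketch that prints what the driver reports for every GPU pair, for comparison against the nvidia-smi topo -m output above:

import torch

# Print the driver-reported peer-access capability for every GPU pair.
n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            ok = torch.cuda.can_device_access_peer(i, j)
            print(f"GPU{i} -> GPU{j}: can_device_access_peer = {ok}")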

hanzhi713 (Contributor)

Yes. You need P2P access for custom allreduce to work. Not all PCIe platforms support this feature. I have a bunch of A30, A10 and T4 machines and the topology is all SYS, but they do support PCIe P2P.

youkaichao (Member, Author)

I'm quite confused about how to detect P2P access capability.

On my 2x L4 machine, torch.cuda.can_device_access_peer(0, 1) == False, but _can_actually_p2p(0, 1) == True:

import torch

print(torch.cuda.can_device_access_peer(0, 1))  # False

def _can_actually_p2p(idx_a, idx_b):
    # Round-trip a tensor between the two devices and check it survives intact.
    dev_i = f"cuda:{idx_a}"
    dev_j = f"cuda:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()

print(_can_actually_p2p(0, 1))  # True

hanzhi713 (Contributor)

I believe cudaMemcpyPeer implements something like this

if can_device_access_peer:
   use_p2p_memcpy()
else:
   use_fallback_implementation_that_goes_through_host_mem()

So even if there's no P2P support, it might still work. We use it to check that, when P2P is reported as supported, it actually works and produces correct results.

youkaichao (Member, Author)

Here is another script:

import torch
import torch.distributed as dist

# Launch with two processes, e.g. torchrun --nproc_per_node=2 script.py
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(dist.get_rank())

data = torch.zeros(2, 2, device='cuda') + dist.get_rank() + 1

def share_cuda_tensor(data, src, rank):
    if rank == src:
        # Serialize the CUDA tensor into a (rebuild_fn, args) pair; the args
        # contain a CUDA IPC memory handle for the tensor's storage.
        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
        dist.broadcast_object_list([[func, args]], src)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, src)
        func, args = recv[0]
        # Rebuild a tensor that maps the source rank's GPU memory.
        data = func(*args)
    return data

data = share_cuda_tensor(data, 0, dist.get_rank())

if dist.get_rank() == 1:
    data += 1
dist.barrier()

print(f"Rank {dist.get_rank()} has data {data}")

torch.multiprocessing.reductions.reduce_tensor(data) internally uses t.untyped_storage()._share_cuda_(), which unconditionally uses cudaIpcGetMemHandle. It still succeeds.

Check https://github.com/pytorch/pytorch/blob/de42af4b0087118cf5527261c532927efcb9a0df/torch/csrc/StorageSharing.cpp#L324 for details.
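
For illustration, a quick inspection of what reduce_tensor returns (a small sketch; runnable on a single GPU):

import torch
from torch.multiprocessing.reductions import reduce_tensor

t = torch.arange(4, device="cuda:0", dtype=torch.float32)
rebuild_fn, args = reduce_tensor(t)
print(rebuild_fn.__name__)    # e.g. rebuild_cuda_tensor for a CUDA tensor
print(type(args), len(args))  # tuple of storage metadata, incl. the CUDA IPC handle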

My question is: why is torch.cuda.can_device_access_peer(0, 1) == False, yet cudaIpcGetMemHandle can still be used for sharing CUDA tensors?

youkaichao (Member, Author)

And if torch.cuda.can_device_access_peer(0, 1) == False but _can_actually_p2p(0, 1) == True, what is the rationale for testing _can_actually_p2p? It is always True.

hanzhi713 (Contributor) commented May 17, 2024

It's not always True. can_device_access_peer=True does not mean that P2P is correctly supported, i.e. the driver can be buggy. vLLM runs on all sorts of consumer hardware, and there are edge cases we must pay attention to. It's not our problem, it's Nvidia's problem, and _can_actually_p2p is our workaround.

See also pytorch/pytorch#119638 for a discussion on this.

hanzhi713 (Contributor) commented May 17, 2024

Not quite sure why cudaIpc can still be used when can_device_access_peer=False. This is a part where documentation or clarification from Nvidia is really lacking.

The doc says "Maps memory exported from another process with cudaIpcGetMemHandle into the current device address space. For contexts on different devices cudaIpcOpenMemHandle can attempt to enable peer access between the devices as if the user called cudaDeviceEnablePeerAccess. This behavior is controlled by the cudaIpcMemLazyEnablePeerAccess flag. cudaDeviceCanAccessPeer can determine if a mapping is possible." and that is what I assume: IPC is only possible if cudaDeviceEnablePeerAccess returns True.

youkaichao (Member, Author)

IPC is only possible if cudaDeviceEnablePeerAccess returns True.

I would say IPC + P2P is only possible if cudaDeviceEnablePeerAccess returns True. There is another case: processes can use IPC on the same GPU, which is how PyTorch uses it (see the sketch below).
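
For illustration, a minimal sketch of that same-GPU pattern: a CUDA tensor is passed through a torch.multiprocessing queue to a second process on the same device, which PyTorch serializes via reduce_tensor / cudaIpcGetMemHandle under the hood:

import torch
import torch.multiprocessing as mp

def consumer(q):
    # The tensor arrives via a CUDA IPC handle and maps the producer's
    # allocation on the same GPU; no peer access is involved.
    t = q.get()
    print(t + 1)

if __name__ == "__main__":
    mp.set_start_method("spawn")  # required for CUDA tensors
    q = mp.Queue()
    t = torch.ones(4, device="cuda:0")
    p = mp.Process(target=consumer, args=(q,))
    p.start()
    q.put(t)   # keep `t` alive here until the consumer has used it
    p.join()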

youkaichao (Member, Author)

Closed as finished in #5689.
