[Core][Distributed] fix _is_full_nvlink detection #4233

Merged · 2 commits · merged Apr 22, 2024
48 changes: 30 additions & 18 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
     # this is expensive to compute at the first time
     # then we cache the result
     if not _can_p2p(rank, world_size):
@@ -138,23 +145,28 @@ def _nvml():
         pynvml.nvmlShutdown()


-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
                     return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
     return True


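Note on why the signature changed from (rank, world_size) to device_ids: pynvml indexes physical GPUs and ignores CUDA_VISIBLE_DEVICES, while rank is a logical CUDA index. With CUDA_VISIBLE_DEVICES="4,5", rank 0 actually runs on physical GPU 4, so the old nvmlDeviceGetHandleByIndex(rank) queried the wrong card. Below is a minimal standalone sketch (not vLLM code; the helper name is illustrative) of the logical-to-physical translation the patch performs. It assumes CUDA_VISIBLE_DEVICES holds comma-separated integer ids, as the patch does.

import os
from typing import List


def physical_device_ids(num_dev: int) -> List[int]:
    # Map logical CUDA device ids to the physical ids that pynvml expects.
    # Assumes CUDA_VISIBLE_DEVICES, if set, contains comma-separated integers
    # (GPU UUID entries are not handled, matching the patch).
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        return list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
    return list(range(num_dev))


# Example: with CUDA_VISIBLE_DEVICES="4,5", logical ranks 0,1 are physical GPUs 4,5.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
print(physical_device_ids(num_dev=2))  # [4, 5]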
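For anyone who wants to reproduce the check outside vLLM, here is a hedged standalone sketch of the same pairwise NVLink probe. The function name and the example id list are illustrative; the pynvml calls are the same ones used by the patched _is_full_nvlink.

from typing import List

import pynvml  # NVML bindings, e.g. from the nvidia-ml-py package


def is_full_nvlink(physical_ids: List[int]) -> bool:
    # True only if every unordered pair of GPUs reports NVLink P2P support.
    pynvml.nvmlInit()
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_ids]
        for i, handle in enumerate(handles):
            for j, peer in enumerate(handles):
                if i < j:
                    try:
                        status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                        if status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        # No NVLink information available; treat as not fully
                        # connected, mirroring the patched behavior.
                        return False
        return True
    finally:
        pynvml.nvmlShutdown()


if __name__ == "__main__":
    # e.g. on a 4-GPU node; adjust the physical id list to your machine.
    print(is_full_nvlink([0, 1, 2, 3]))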