diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index f83caef879da3..7602897d3dd8f 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -145,8 +145,10 @@ def _is_full_nvlink(rank, world_size): for i in range(world_size): if i != rank: try: - link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) - if not link_state: + peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i) + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: return False except pynvml.NVMLError as error: logger.info(