Skip to content

Commit

Permalink
[Core] nccl integrity check and test (vllm-project#4155)
Browse files Browse the repository at this point in the history
[Core] Add integrity check during initialization; add test for it (vllm-project#4155)
  • Loading branch information
youkaichao authored Apr 18, 2024
1 parent aa59b2f commit 01d5891
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 26 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s test_pynccl.py
- pytest -v -s test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
Expand Down
43 changes: 43 additions & 0 deletions tests/distributed/test_pynccl_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import multiprocessing
import tempfile


def target_fn(env, filepath):
from vllm.utils import update_environment_variables
update_environment_variables(env)
from vllm.utils import nccl_integrity_check
nccl_integrity_check(filepath)


def test_library_file():
# note: don't import vllm.distributed.device_communicators.pynccl
# before running this test, otherwise the library file will be loaded
# and it might interfere with the test
from vllm.utils import find_nccl_library
so_file = find_nccl_library()
with open(so_file, 'rb') as f:
content = f.read()
try:
# corrupt the library file, should raise an exception
with open(so_file, 'wb') as f:
f.write(content[:len(content) // 2])
p = multiprocessing.Process(target=target_fn, args=({}, so_file))
p.start()
p.join()
assert p.exitcode != 0

# move the library file to a tmp path
# test VLLM_NCCL_SO_PATH
fd, path = tempfile.mkstemp()
with open(path, 'wb') as f:
f.write(content)
p = multiprocessing.Process(target=target_fn,
args=({
"VLLM_NCCL_SO_PATH": path
}, path))
p.start()
p.join()
assert p.exitcode == 0
finally:
with open(so_file, 'wb') as f:
f.write(content)
38 changes: 12 additions & 26 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,35 @@

import ctypes
import datetime
import glob
import os
import platform

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check

logger = init_logger(__name__)

so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path)
vllm_nccl_path = files[0] if files else None

# manually load the nccl library
if so_file:
logger.info(
f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Loading nccl from library {so_file}")
so_file = find_nccl_library()

try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
f"Failed to load NCCL library from {so_file} ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
"Otherwise, the nccl library might not exist, be corrupted "
f"or it does not support the current platform {platform.platform()}."
f"One solution is to download libnccl2 version 2.18 from "
f"https://developer.download.nvidia.com/compute/cuda/repos/ "
f"and extract the libnccl.so.2 file. If you already have the "
f"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.")
raise e

Expand Down
51 changes: 51 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import enum
import gc
import glob
import os
import socket
import subprocess
Expand Down Expand Up @@ -517,3 +518,53 @@ def init_cached_hf_modules():
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()


def nccl_integrity_check(filepath):
"""
when the library is corrupted, we cannot catch
the exception in python. it will crash the process.
instead, we use the exit code of `ldd` to check
if the library is corrupted. if not, we will return
the version of the library.
"""
exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
if exit_code != 0:
raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
import ctypes

nccl = ctypes.CDLL(filepath)
version = ctypes.c_int()
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
result = nccl.ncclGetVersion(ctypes.byref(version))
assert result == 0
return version.value


def find_nccl_library():
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path)
vllm_nccl_path = files[0] if files else None

# manually load the nccl library
if so_file:
logger.info(
f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
)
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
return so_file

0 comments on commit 01d5891

Please sign in to comment.