Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Add integrity check during initialization; add test for it #4155

Merged
merged 10 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s test_pynccl.py
- pytest -v -s test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
Expand Down
43 changes: 43 additions & 0 deletions tests/distributed/test_pynccl_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import multiprocessing
import tempfile


def target_fn(env, filepath):
from vllm.utils import update_environment_variables
update_environment_variables(env)
from vllm.utils import nccl_integrity_check
nccl_integrity_check(filepath)


def test_library_file():
# note: don't import vllm.distributed.device_communicators.pynccl
# before running this test, otherwise the library file will be loaded
# and it might interfere with the test
from vllm.utils import find_nccl_library
so_file = find_nccl_library()
with open(so_file, 'rb') as f:
content = f.read()
try:
# corrupt the library file, should raise an exception
with open(so_file, 'wb') as f:
f.write(content[:len(content) // 2])
p = multiprocessing.Process(target=target_fn, args=({}, so_file))
p.start()
p.join()
assert p.exitcode != 0

# move the library file to a tmp path
# test VLLM_NCCL_SO_PATH
fd, path = tempfile.mkstemp()
with open(path, 'wb') as f:
f.write(content)
p = multiprocessing.Process(target=target_fn,
args=({
"VLLM_NCCL_SO_PATH": path
}, path))
p.start()
p.join()
assert p.exitcode == 0
finally:
with open(so_file, 'wb') as f:
f.write(content)
38 changes: 12 additions & 26 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,35 @@

import ctypes
import datetime
import glob
import os
import platform

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check

logger = init_logger(__name__)

so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path)
vllm_nccl_path = files[0] if files else None

# manually load the nccl library
if so_file:
logger.info(
f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Loading nccl from library {so_file}")
so_file = find_nccl_library()

try:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
f"Failed to load NCCL library from {so_file} ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
"Otherwise, the nccl library might not exist, be corrupted "
f"or it does not support the current platform {platform.platform()}."
f"One solution is to download libnccl2 version 2.18 from "
f"https://developer.download.nvidia.com/compute/cuda/repos/ "
f"and extract the libnccl.so.2 file. If you already have the "
f"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.")
raise e

Expand Down
51 changes: 51 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import enum
import gc
import glob
import os
import socket
import subprocess
Expand Down Expand Up @@ -517,3 +518,53 @@ def init_cached_hf_modules():
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()


def nccl_integrity_check(filepath):
"""
when the library is corrupted, we cannot catch
the exception in python. it will crash the process.
instead, we use the exit code of `ldd` to check
if the library is corrupted. if not, we will return
the version of the library.
"""
exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
if exit_code != 0:
raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
import ctypes

nccl = ctypes.CDLL(filepath)
version = ctypes.c_int()
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
result = nccl.ncclGetVersion(ctypes.byref(version))
assert result == 0
return version.value


def find_nccl_library():
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path)
vllm_nccl_path = files[0] if files else None

# manually load the nccl library
if so_file:
logger.info(
f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
)
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
return so_file
Loading