diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f6781de61af19..d7d2b930e6e7a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -22,10 +22,13 @@ steps: working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. -- label: Distributed Correctness Test - command: pytest -v -s --forked test_basic_distributed_correctness.py +- label: Distributed Tests working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. + commands: + - pytest -v -s --forked test_pynccl.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/Dockerfile b/Dockerfile index 1f254c76fe5af..20bbff34b7fcb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal #################### RUNTIME BASE IMAGE #################### # We used base cuda image because pytorch installs its own cuda libraries. -# However cupy depends on cuda libraries so we had to switch to the runtime image +# However pynccl depends on cuda libraries so we had to switch to the runtime image # In the future it would be nice to get a container with pytorch and cuda without duplicating cuda FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a45265d79a6ac..a09de99f7a468 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH" # In that case, we need to use the python reference attention implementation in vllm ARG BUILD_FA="1" -# whether to build cupy on rocm -ARG BUILD_CUPY="1" - # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi -# build cupy -RUN if [ "$BUILD_CUPY" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \ - && cd cupy \ - && pip install mpi4py-mpich \ - && pip install scipy==1.9.3 \ - && pip install cython==0.29.* \ - && env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \ - && export CUPY_INSTALL_USE_HIP=1 \ - && export ROCM_HOME=/opt/rocm \ - && export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \ - && pip install . \ - && cd ..; \ - fi - COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/requirements.txt b/requirements.txt index eb9977d93dd8d..6d75067b34a7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,3 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 outlines == 0.0.34 -cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. 
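The new pipeline entry runs the distributed correctness test once per model, passing the model name through `TEST_DIST_MODEL` so that each pytest process only ever loads a single model. A minimal sketch of how a test module can consume that variable is shown below; it mirrors the pattern used by the updated `test_basic_distributed_correctness.py`, but the fallback default and the trivial test body are illustrative assumptions, not part of this change.

```python
# Illustrative sketch: pick the model under test from an environment variable,
# one model per pytest process. The .get() default is an assumption added here
# for convenient local runs; the CI always sets TEST_DIST_MODEL explicitly.
import os

import pytest

MODELS = [
    os.environ.get("TEST_DIST_MODEL", "facebook/opt-125m"),
]


@pytest.mark.parametrize("model", MODELS)
def test_model_selected_by_env(model: str) -> None:
    # Each CI invocation sets exactly one model, so a single process never has
    # to hold two models in GPU memory at the same time.
    assert model
```

Running each model in its own process matters because vLLM claims nearly all free GPU memory at startup, so two models cannot coexist in one test process.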
diff --git a/setup.py b/setup.py
index 9c9a428f94683..941d315d573bb 100644
--- a/setup.py
+++ b/setup.py
@@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
     if _is_cuda():
         with open(get_path("requirements.txt")) as f:
             requirements = f.read().strip().split("\n")
-        if get_nvcc_cuda_version() <= Version("11.8"):
-            # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
-            for i in range(len(requirements)):
-                if requirements[i].startswith("cupy-cuda12x"):
-                    requirements[i] = "cupy-cuda11x"
-                    break
     elif _is_hip():
         with open(get_path("requirements-rocm.txt")) as f:
             requirements = f.read().strip().split("\n")
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
index 82075356fccbd..1eba14d7a6422 100644
--- a/tests/distributed/test_basic_distributed_correctness.py
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -1,13 +1,22 @@
 """Compare the outputs of HF and distributed vLLM when using greedy sampling.
-
-Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
+vLLM allocates all of the available GPU memory, so the tests have to run one
+model at a time. The model name is therefore passed in through the
+TEST_DIST_MODEL environment variable.
+Run:
+```sh
+TEST_DIST_MODEL=facebook/opt-125m pytest \
+    test_basic_distributed_correctness.py
+TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
+    test_basic_distributed_correctness.py
+```
 """
+import os
+
 import pytest
 import torch
 
 MODELS = [
-    "facebook/opt-125m",
-    "meta-llama/Llama-2-7b-hf",
+    os.environ["TEST_DIST_MODEL"],
 ]
diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index 1d376b18a66b3..0395f7200fd77 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -2,6 +2,8 @@
 
 Run `pytest tests/distributed/test_comm_ops.py --forked`.
""" +import os + import pytest import ray import torch @@ -16,6 +18,12 @@ @ray.remote(num_gpus=1, max_calls=1) def all_reduce_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) num_elements = 8 @@ -32,6 +40,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) num_dimensions = 3 @@ -54,6 +68,12 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) test_dict = { diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py new file mode 100644 index 0000000000000..797f18915dec9 --- /dev/null +++ b/tests/distributed/test_pynccl.py @@ -0,0 +1,90 @@ +import multiprocessing +import os + +import pytest +import torch + +from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, + ncclGetUniqueId) + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes = [] + for i in range(number_of_processes): + env = os.environ.copy() + env['RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + +def update_env(fn): + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function + def wrapper(env): + import os + os.environ.update(env) + fn() + + return wrapper + + +@update_env +def worker_fn(): + comm = NCCLCommunicator() + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == comm.world_size + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl(): + distributed_run(worker_fn, 2) + + +@update_env +def worker_fn_with_cudagraph(): + with torch.no_grad(): + graph = torch.cuda.CUDAGraph() + comm = NCCLCommunicator() + # run something in the default stream to initialize torch engine + a = torch.ones((4, 4), 
device=f'cuda:{comm.rank}') + torch.cuda.synchronize() + with torch.cuda.graph(graph, stream=comm.stream): + # operation during the graph capture is recorded but not executed + # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa + comm.all_reduce(a) + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**0 + graph.replay() + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_with_cudagraph(): + distributed_run(worker_fn_with_cudagraph, 2) + + +def test_ncclGetUniqueId(): + unique_id = ncclGetUniqueId() + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index f2fc8aec9887b..d8288832a5300 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,11 +188,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - # FIXME(woosuk): We are not properly initializing cupy NCCL when + # FIXME(woosuk): We are not properly initializing pynccl when # we have multiple nodes. - self._run_workers("init_device", - cupy_port=get_open_port() - if not model_config.enforce_eager else None) + self._run_workers("init_device") self._run_workers( "load_model", max_concurrent_workers=self.parallel_config. diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 04b30b4d093d7..9cbb40708dd5b 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -4,12 +4,12 @@ import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.custom_all_reduce import ( custom_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce) + get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_cupy_nccl_enabled_for_all_reduce(): + if is_pynccl_enabled_for_all_reduce(): # TODO: support multiple parallel groups. 
- cupy_utils.all_reduce(input_) + pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py deleted file mode 100644 index f8cffc01e3c36..0000000000000 --- a/vllm/model_executor/parallel_utils/cupy_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -"""CuPy utilities for all-reduce. - -We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing -CUDA graphs, because torch.distributed.all_reduce causes errors when capturing -CUDA graphs. - -NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8. -TODO: Remove this file when torch.distributed.all_reduce is fixed. -""" -import contextlib - -import torch -from torch.distributed import ReduceOp - -try: - import cupy - from cupy.cuda import nccl - from cupyx.distributed import NCCLBackend -except ImportError as e: - cupy = e - nccl = None - - class NCCLBackend: - ... - - -_OP_MAPPING = { - ReduceOp.SUM: "sum", - ReduceOp.PRODUCT: "prod", - ReduceOp.MIN: "min", - ReduceOp.MAX: "max", -} - - -class NCCLBackendWithBFloat16(NCCLBackend): - # This is enough to add bfloat16 support for most operations, - # but broadcast will fail (will require changes in compiled - # cupy code). - def _get_nccl_dtype_and_count(self, array, count=None): - nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count) - torch_dtype = getattr(array, "_torch_dtype", None) - if torch_dtype is torch.bfloat16: - nccl_dtype = nccl.NCCL_BFLOAT16 - return nccl_dtype, count - - def barrier(self) -> None: - raise RuntimeError( - "Currently, CuPy NCCL barrier is not supported since the TCP " - "store is immediately stopped after the initialization.") - - -_NCCL_BACKEND = None -_WORLD_SIZE = 0 - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return _NCCL_BACKEND is not None - - -@contextlib.contextmanager -def set_cupy_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream, - stream.device_index) - with cupy_stream: - yield - - -def init_process_group(world_size: int, rank: int, host: str, - port: int) -> None: - """Initializes the CuPy NCCL backend. - - # TODO: handle NCCL timeouts. - """ - assert not is_initialized() - - if isinstance(cupy, Exception): - raise ImportError( - "NCCLBackend is not available. Please install cupy.") from cupy - - # TODO(woosuk): Create TP and PP process groups for CuPy. - global _NCCL_BACKEND - global _WORLD_SIZE - assert world_size > 0, f"{world_size=} should be a positive integer" - assert 0 <= rank < world_size, ( - f"{rank=} should be a integer between [0, {world_size})") - - cupy.cuda.runtime.setDevice(torch.cuda.current_device()) - _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port) - _WORLD_SIZE = world_size - - # Stop the TCP store to prevent the deadlock issues at termination time. - # FIXME(woosuk): This is hacky. Find a more robust solution. - if rank == 0 and hasattr(_NCCL_BACKEND, "_store"): - _NCCL_BACKEND._store.stop() - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - # Hack to support bfloat16 - torch_dtype = input_.dtype - if torch_dtype is torch.bfloat16: - # We need to view as float16, otherwise - # cupy will fail. 
This will not change - # the underlying data. - input_ = input_.view(torch.float16) - cupy_input = cupy.asarray(input_) - cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access - _NCCL_BACKEND.all_reduce(in_array=cupy_input, - out_array=cupy_input, - op=_OP_MAPPING[op]) - - -def destroy_process_group() -> None: - """Destroys the NCCL backend.""" - global _NCCL_BACKEND - global _WORLD_SIZE - _NCCL_BACKEND = None - _WORLD_SIZE = 0 - - -def get_world_size() -> int: - """Returns the world size.""" - return _WORLD_SIZE - - -def get_nccl_backend(): - return _NCCL_BACKEND diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index c821936d06e4e..bcda5ebf8548b 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -7,7 +7,7 @@ import torch -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None @@ -210,36 +210,36 @@ def destroy_model_parallel(): global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None - # Destroy the cupy states if any. - cupy_utils.destroy_process_group() + # Destroy the pynccl states if any. + pynccl_utils.destroy_process_group() -# Whether to use cupy for nccl all reduce. -# We use cupy for all reduce when using CUDA graph, because torch.distributed +# Whether to use pynccl for nccl all reduce. +# We use pynccl for all reduce when using CUDA graph, because torch.distributed # is not well supported by CUDA graph. -_ENABLE_CUPY_FOR_ALL_REDUCE = False +_ENABLE_PYNCCL_FOR_ALL_REDUCE = False @contextlib.contextmanager -def with_cupy_nccl_for_all_reduce(): - """use CuPy nccl instead of torch.distributed for all reduce""" +def with_pynccl_for_all_reduce(): + """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: # No-op. - # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. + # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. yield else: - global _ENABLE_CUPY_FOR_ALL_REDUCE - old = _ENABLE_CUPY_FOR_ALL_REDUCE - _ENABLE_CUPY_FOR_ALL_REDUCE = True + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + old = _ENABLE_PYNCCL_FOR_ALL_REDUCE + _ENABLE_PYNCCL_FOR_ALL_REDUCE = True stream = torch.cuda.current_stream() - with cupy_utils.set_cupy_stream(stream): + with pynccl_utils.set_pynccl_stream(stream): yield - _ENABLE_CUPY_FOR_ALL_REDUCE = old + _ENABLE_PYNCCL_FOR_ALL_REDUCE = old -def is_cupy_nccl_enabled_for_all_reduce(): - """check if CuPy nccl is enabled for all reduce""" - global _ENABLE_CUPY_FOR_ALL_REDUCE - return _ENABLE_CUPY_FOR_ALL_REDUCE +def is_pynccl_enabled_for_all_reduce(): + """check if pynccl is enabled for all reduce""" + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + return _ENABLE_PYNCCL_FOR_ALL_REDUCE diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py new file mode 100644 index 0000000000000..e8f0f74ad186b --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -0,0 +1,258 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. 
We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#  contains many other CUDA API calls that are not allowed while capturing a
+#  CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related to NCCL versions and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this: it requires
+# recompiling the code every time we want to switch between NCCL versions.
+# The current implementation, with a **pure** Python wrapper, is more
+# flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import datetime
+import logging
+import os
+
+# ===================== import region =====================
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+
+logger = logging.getLogger(__name__)
+
+so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
+
+# Manually load the NCCL library.
+if so_file:
+    logger.info(
+        f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
+else:
+    if torch.version.cuda is not None:
+        so_file = "libnccl.so"
+    elif torch.version.hip is not None:
+        so_file = "librccl.so"
+    else:
+        raise ValueError("NCCL only supports CUDA and ROCm backends.")
+    logger.debug(f"Loading nccl from library {so_file}")
+
+try:
+    nccl = ctypes.CDLL(so_file)
+except Exception as e:
+    logger.error(
+        f"Failed to load NCCL library from {so_file}. "
+        "This is expected if you are not running on NVIDIA/AMD GPUs. "
+ "Otherwise please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.") + raise e + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int + +# equivalent to c declaration: +# ncclResult_t ncclGetVersion(int *version); +_c_ncclGetVersion = nccl.ncclGetVersion +_c_ncclGetVersion.restype = ctypes.c_int +_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] + + +def ncclGetVersion() -> str: + version = ctypes.c_int() + result = _c_ncclGetVersion(ctypes.byref(version)) + assert result == 0 + # something like 21903 --> "2.19.3" + version_str = str(version.value) + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + +class NcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +# equivalent to c declaration: +# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); +_c_ncclGetUniqueId = nccl.ncclGetUniqueId +_c_ncclGetUniqueId.restype = ctypes.c_int +_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] + + +def ncclGetUniqueId() -> NcclUniqueId: + unique_id = NcclUniqueId() + result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) + assert result == 0 + return unique_id + + +# equivalent to c declaration: +# ncclResult_t ncclCommInitRank( +# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); +# note that ncclComm_t is a pointer type, so the first argument +# is a pointer to a pointer +_c_ncclCommInitRank = nccl.ncclCommInitRank +_c_ncclCommInitRank.restype = ctypes.c_int +_c_ncclCommInitRank.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int +] + + +# enums +class ncclDataType_t(ctypes.c_int): + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +class ncclRedOp_t(ctypes.c_int): + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +# equivalent to c declaration: +# ncclResult_t ncclAllReduce( +# const void* sendbuff, void* recvbuff, size_t count, +# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, +# udaStream_t stream); +# note that cudaStream_t is a pointer type, so the last argument is a pointer +_c_ncclAllReduce = nccl.ncclAllReduce +_c_ncclAllReduce.restype = ctypes.c_int 
+_c_ncclAllReduce.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p +] + +# equivalent to c declaration: +# ncclResult_t ncclCommDestroy(ncclComm_t comm); +_c_ncclCommDestroy = nccl.ncclCommDestroy +_c_ncclCommDestroy.restype = ctypes.c_int +_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] + + +class NCCLCommunicator: + + def __init__( + self, + backend=None, + init_method=None, + timeout=datetime.timedelta(seconds=10), + world_size: int = -1, + rank: int = -1, + store=None, + group_name: str = "", + pg_options=None, + ): + if not dist.is_initialized(): + backend = backend or "nccl" + assert backend == 'nccl', ( + "only use nccl backend for starting the NCCL communicator") + dist.init_process_group(backend=backend, + init_method=init_method, + timeout=timeout, + world_size=world_size, + rank=rank, + store=store, + group_name=group_name, + pg_options=pg_options) + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + torch.cuda.set_device(self.rank) + if self.rank == 0: + self.unique_id = ncclGetUniqueId() + else: + self.unique_id = NcclUniqueId() + tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( + self.rank) + dist.broadcast(tensor, src=0) + byte_list = tensor.cpu().tolist() + self.unique_id = NcclUniqueId() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + self.comm = ctypes.c_void_p() + result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank) + assert result == 0 + self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}") + + def all_reduce(self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if stream is None: + stream = self.stream + result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), + ctypes.c_void_p(tensor.data_ptr()), + tensor.numel(), + ncclDataType_t.from_torch(tensor.dtype), + ncclRedOp_t.from_torch(op), self.comm, + ctypes.c_void_p(stream.cuda_stream)) + assert result == 0 + + def __del__(self): + dist.destroy_process_group() + _c_ncclCommDestroy(self.comm) diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py new file mode 100644 index 0000000000000..a12d620d7a24c --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -0,0 +1,64 @@ +import contextlib +import logging +from typing import Optional + +import torch +from torch.distributed import ReduceOp + +logger = logging.getLogger(__name__) + +try: + from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, + ncclGetVersion) + logger.info(f"vLLM is using nccl=={ncclGetVersion()}") +except Exception as e: + # in non-NVIDIA environments, we can't import the nccl module + # e.g. 
when running on machines with AMD GPUs + logger.info(f"Failed to import NCCL library: {e}") + logger.info("It is expected if you are not running on NVIDIA GPUs.") + pass + +comm: Optional["NCCLCommunicator"] = None + + +def is_initialized() -> bool: + """Returns whether the NCCL backend is initialized.""" + return comm is not None + + +@contextlib.contextmanager +def set_pynccl_stream(stream: torch.cuda.Stream): + """Set the cuda stream for communication""" + try: + comm.stream = stream + yield + finally: + pass + + +def init_process_group(world_size: int, rank: int, init_method: str) -> None: + assert not is_initialized() + global comm + comm = NCCLCommunicator(init_method=init_method, + world_size=world_size, + rank=rank) + + +def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: + """All-reduces the input tensor across the process group.""" + assert input_.is_cuda, f"{input_} should be a cuda tensor" + comm.all_reduce(input_, op) + + +def destroy_process_group() -> None: + global comm + comm = None + + +def get_world_size() -> int: + """Returns the world size.""" + return comm.world_size + + +def get_nccl_backend(): + return comm diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 75bf6ce373d93..5b2eeafad197e 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -16,10 +16,7 @@ def init_test_distributed_environment( worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - parallel_config, - rank, - cupy_port=None, - distributed_init_method=distributed_init_method) + parallel_config, rank, distributed_init_method=distributed_init_method) def multi_process_tensor_parallel( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a08c3cbf5836..f0c98700ab749 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -15,11 +15,11 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce +from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( - with_cupy_nccl_for_all_reduce) + with_pynccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) @@ -764,7 +764,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. - self.cupy_nccl_backend = cupy_utils.get_nccl_backend() + self.pynccl_backend = pynccl_utils.get_nccl_backend() assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " @@ -794,11 +794,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ] # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or CuPy NCCL. When not using CUDA + # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA # graph, we use either custom all-reduce kernel or PyTorch NCCL. 
# We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or CuPy NCCL if it is disabled or not supported. + # to PyTorch or pynccl if it is disabled or not supported. with custom_all_reduce.capture(): # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. @@ -846,12 +846,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") def __del__(self) -> None: - # Delete the CUDA graphs before deleting the CuPy NCCL communicator. + # Delete the CUDA graphs before deleting the pynccl communicator. # NOTE(woosuk): This is necessary because otherwise deadlocks can # happen. # FIXME(woosuk): This is a bit hacky. Find a more robust solution. + # TODO(youkaichao): when we get enough user feedback that pynccl is + # more stable than cupy, we can remove this, e.g. in v0.4.1. self.graph_runners.clear() - self.cupy_nccl_backend = None + self.pynccl_backend = None @property def vocab_size(self) -> int: @@ -879,7 +881,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_cupy_nccl(): + with _maybe_pynccl(): self.model( input_ids, positions, @@ -894,7 +896,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_cupy_nccl(): + with _maybe_pynccl(): hidden_states = self.model( input_ids, positions, @@ -947,9 +949,10 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager -def _maybe_cupy_nccl(): - if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized(): - with with_cupy_nccl_for_all_reduce(): +def _maybe_pynccl(): + if pynccl_utils.is_initialized( + ) and not custom_all_reduce.is_initialized(): + with with_pynccl_for_all_reduce(): yield else: yield diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 46a62fa693258..6459c0cda669a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,7 +10,7 @@ ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar @@ -75,7 +75,7 @@ def __init__( self.cache_engine = None self.gpu_cache = None - def init_device(self, cupy_port: Optional[int] = None) -> None: + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow @@ -98,7 +98,7 @@ def init_device(self, cupy_port: Optional[int] = None) -> None: f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. init_distributed_environment(self.parallel_config, self.rank, - cupy_port, self.distributed_init_method) + self.distributed_init_method) # Set random seed. 
set_random_seed(self.model_config.seed) @@ -250,7 +250,6 @@ def get_cache_block_size_bytes(self, block_size: int, def init_distributed_environment( parallel_config: ParallelConfig, rank: int, - cupy_port: Optional[int], distributed_init_method: Optional[str] = None, ) -> None: """Initialize the distributed environment.""" @@ -273,28 +272,27 @@ def init_distributed_environment( init_method=distributed_init_method, ) - if cupy_utils.is_initialized(): - cupy_world_size = cupy_utils.get_world_size() - if cupy_world_size != parallel_config.world_size: + if pynccl_utils.is_initialized(): + pynccl_world_size = pynccl_utils.get_world_size() + if pynccl_world_size != parallel_config.world_size: raise RuntimeError( - "cupy.distributed is already initialized but the cupy world " + "pynccl is already initialized but the pynccl world " "size does not match parallel_config.world_size " - f"({cupy_world_size} vs. {parallel_config.world_size}).") - elif (parallel_config.world_size > 1 and cupy_port is not None): - # NOTE(woosuk): We don't initialize CuPy process group when world size + f"({pynccl_world_size} vs. {parallel_config.world_size}).") + elif parallel_config.world_size > 1: + # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. # TODO(woosuk): Support multi-node connection. - cupy_utils.init_process_group( + pynccl_utils.init_process_group( world_size=parallel_config.world_size, rank=rank, - host="localhost", - port=cupy_port, + init_method=distributed_init_method, ) # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) - if cupy_utils.is_initialized(): - cupy_utils.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
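For reference, the usage pattern exercised by the new `tests/distributed/test_pynccl.py` can be condensed as follows. This is an illustrative sketch, not part of the diff: it assumes two visible GPUs and a loadable `libnccl.so`, and it reuses the `NCCLCommunicator` API added above, which rendezvouses through the usual `RANK`/`WORLD_SIZE`/`MASTER_ADDR`/`MASTER_PORT` environment variables.

```python
# Sketch of direct NCCLCommunicator usage (assumes 2 GPUs and NCCL available).
import multiprocessing
import os

import torch

from vllm.model_executor.parallel_utils.pynccl import NCCLCommunicator


def _worker(rank: int, world_size: int) -> None:
    # NCCLCommunicator falls back to env:// rendezvous when torch.distributed
    # is not yet initialized, so set the standard variables first.
    os.environ.update({
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12345",
    })
    comm = NCCLCommunicator()  # dist.init_process_group + ncclCommInitRank
    tensor = torch.ones(1024, dtype=torch.float32).cuda(comm.rank)
    torch.cuda.synchronize()  # make sure the tensor is materialized
    comm.all_reduce(tensor)   # in-place sum across ranks on comm.stream
    comm.stream.synchronize()
    assert tensor.mean().item() == world_size


if __name__ == "__main__":
    world_size = 2
    procs = [
        multiprocessing.Process(target=_worker, args=(i, world_size))
        for i in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

The CUDA-graph variant in the new test follows the same pattern but records `comm.all_reduce` inside `torch.cuda.graph(graph, stream=comm.stream)` and then calls `graph.replay()`, which mirrors the capture path `model_runner.py` now takes through `with_pynccl_for_all_reduce()`.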