From 24868ee89aafc6d380f0ab94381c65880501c776 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 17:51:02 -0700 Subject: [PATCH 01/29] update CMakeLists.txt --- CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66842e6845edd..71acc31a39dc1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,10 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") +# used when building pytorch-related extensions +# TODO: only compute this if we are building wheels +# TODO: otherwise autodetect the version from current GPUs +set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.9;9.0") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") From 0af327ba5fea9898cdb3e414e8677d1f22cfc2a0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 17:54:35 -0700 Subject: [PATCH 02/29] import file from huge PR --- tests/distributed/test_pynccl.py | 88 +++++++ .../parallel_utils/communication_op.py | 15 +- .../parallel_utils/cupy_utils.py | 130 ---------- .../parallel_utils/parallel_state.py | 34 +-- vllm/model_executor/parallel_utils/pynccl.py | 239 ++++++++++++++++++ .../parallel_utils/pynccl_utils.py | 67 +++++ vllm/worker/model_runner.py | 217 ++++++---------- vllm/worker/worker.py | 45 ++-- 8 files changed, 513 insertions(+), 322 deletions(-) create mode 100644 tests/distributed/test_pynccl.py delete mode 100644 vllm/model_executor/parallel_utils/cupy_utils.py create mode 100644 vllm/model_executor/parallel_utils/pynccl.py create mode 100644 vllm/model_executor/parallel_utils/pynccl_utils.py diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py new file mode 100644 index 0000000000000..58376306c277e --- /dev/null +++ b/tests/distributed/test_pynccl.py @@ -0,0 +1,88 @@ +# this script is not run with `pytest`. +# It is run with `torchrun`. 
+import os +import multiprocessing +import pytest +import torch +from vllm.model_executor.parallel_utils.pynccl import ( + NCCLCommunicator, + ncclGetUniqueId, +) + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes = [] + for i in range(number_of_processes): + env = os.environ.copy() + env['RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + +def update_env(fn): + + def wrapper(env): + import os + os.environ.update(env) + fn() + + return wrapper + + +@update_env +def worker_fn(): + comm = NCCLCommunicator() + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == comm.world_size + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl(): + distributed_run(worker_fn, 2) + + +@update_env +def worker_fn_with_cudagraph(): + with torch.no_grad(): + graph = torch.cuda.CUDAGraph() + comm = NCCLCommunicator() + # run something in the default stream to initialize torch engine + a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + torch.cuda.synchronize() + with torch.cuda.graph(graph, stream=comm.stream): + comm.all_reduce(a) + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**0 + graph.replay() + comm.stream.synchronize() + assert a.mean().cpu().item() == comm.world_size**2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_with_cudagraph(): + distributed_run(worker_fn_with_cudagraph, 2) + + +def test_ncclGetUniqueId(): + unique_id = ncclGetUniqueId() + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 04b30b4d093d7..28433d31f56a5 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -4,12 +4,15 @@ import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, + is_pynccl_enabled_for_all_reduce, +) from vllm.model_executor.parallel_utils.custom_all_reduce import ( custom_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -30,9 +33,9 @@ def 
tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_cupy_nccl_enabled_for_all_reduce(): + if is_pynccl_enabled_for_all_reduce(): # TODO: support multiple parallel groups. - cupy_utils.all_reduce(input_) + pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py deleted file mode 100644 index f8cffc01e3c36..0000000000000 --- a/vllm/model_executor/parallel_utils/cupy_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -"""CuPy utilities for all-reduce. - -We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing -CUDA graphs, because torch.distributed.all_reduce causes errors when capturing -CUDA graphs. - -NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8. -TODO: Remove this file when torch.distributed.all_reduce is fixed. -""" -import contextlib - -import torch -from torch.distributed import ReduceOp - -try: - import cupy - from cupy.cuda import nccl - from cupyx.distributed import NCCLBackend -except ImportError as e: - cupy = e - nccl = None - - class NCCLBackend: - ... - - -_OP_MAPPING = { - ReduceOp.SUM: "sum", - ReduceOp.PRODUCT: "prod", - ReduceOp.MIN: "min", - ReduceOp.MAX: "max", -} - - -class NCCLBackendWithBFloat16(NCCLBackend): - # This is enough to add bfloat16 support for most operations, - # but broadcast will fail (will require changes in compiled - # cupy code). - def _get_nccl_dtype_and_count(self, array, count=None): - nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count) - torch_dtype = getattr(array, "_torch_dtype", None) - if torch_dtype is torch.bfloat16: - nccl_dtype = nccl.NCCL_BFLOAT16 - return nccl_dtype, count - - def barrier(self) -> None: - raise RuntimeError( - "Currently, CuPy NCCL barrier is not supported since the TCP " - "store is immediately stopped after the initialization.") - - -_NCCL_BACKEND = None -_WORLD_SIZE = 0 - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return _NCCL_BACKEND is not None - - -@contextlib.contextmanager -def set_cupy_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream, - stream.device_index) - with cupy_stream: - yield - - -def init_process_group(world_size: int, rank: int, host: str, - port: int) -> None: - """Initializes the CuPy NCCL backend. - - # TODO: handle NCCL timeouts. - """ - assert not is_initialized() - - if isinstance(cupy, Exception): - raise ImportError( - "NCCLBackend is not available. Please install cupy.") from cupy - - # TODO(woosuk): Create TP and PP process groups for CuPy. - global _NCCL_BACKEND - global _WORLD_SIZE - assert world_size > 0, f"{world_size=} should be a positive integer" - assert 0 <= rank < world_size, ( - f"{rank=} should be a integer between [0, {world_size})") - - cupy.cuda.runtime.setDevice(torch.cuda.current_device()) - _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port) - _WORLD_SIZE = world_size - - # Stop the TCP store to prevent the deadlock issues at termination time. - # FIXME(woosuk): This is hacky. Find a more robust solution. 
- if rank == 0 and hasattr(_NCCL_BACKEND, "_store"): - _NCCL_BACKEND._store.stop() - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - # Hack to support bfloat16 - torch_dtype = input_.dtype - if torch_dtype is torch.bfloat16: - # We need to view as float16, otherwise - # cupy will fail. This will not change - # the underlying data. - input_ = input_.view(torch.float16) - cupy_input = cupy.asarray(input_) - cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access - _NCCL_BACKEND.all_reduce(in_array=cupy_input, - out_array=cupy_input, - op=_OP_MAPPING[op]) - - -def destroy_process_group() -> None: - """Destroys the NCCL backend.""" - global _NCCL_BACKEND - global _WORLD_SIZE - _NCCL_BACKEND = None - _WORLD_SIZE = 0 - - -def get_world_size() -> int: - """Returns the world size.""" - return _WORLD_SIZE - - -def get_nccl_backend(): - return _NCCL_BACKEND diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index c821936d06e4e..63890d9cd5bd8 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -7,7 +7,7 @@ import torch -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None @@ -210,36 +210,36 @@ def destroy_model_parallel(): global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None - # Destroy the cupy states if any. - cupy_utils.destroy_process_group() + # Destroy the pynccl states if any. + pynccl_utils.destroy_process_group() -# Whether to use cupy for nccl all reduce. -# We use cupy for all reduce when using CUDA graph, because torch.distributed +# Whether to use pynccl for nccl all reduce. +# We use pynccl for all reduce when using CUDA graph, because torch.distributed # is not well supported by CUDA graph. -_ENABLE_CUPY_FOR_ALL_REDUCE = False +_ENABLE_PYNCCL_FOR_ALL_REDUCE = False @contextlib.contextmanager -def with_cupy_nccl_for_all_reduce(): - """use CuPy nccl instead of torch.distributed for all reduce""" +def with_pynccl_for_all_reduce(): + """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: # No-op. # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. 
yield else: - global _ENABLE_CUPY_FOR_ALL_REDUCE - old = _ENABLE_CUPY_FOR_ALL_REDUCE - _ENABLE_CUPY_FOR_ALL_REDUCE = True + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + old = _ENABLE_PYNCCL_FOR_ALL_REDUCE + _ENABLE_PYNCCL_FOR_ALL_REDUCE = True stream = torch.cuda.current_stream() - with cupy_utils.set_cupy_stream(stream): + with pynccl_utils.set_pynccl_stream(stream): yield - _ENABLE_CUPY_FOR_ALL_REDUCE = old + _ENABLE_PYNCCL_FOR_ALL_REDUCE = old -def is_cupy_nccl_enabled_for_all_reduce(): - """check if CuPy nccl is enabled for all reduce""" - global _ENABLE_CUPY_FOR_ALL_REDUCE - return _ENABLE_CUPY_FOR_ALL_REDUCE +def is_pynccl_enabled_for_all_reduce(): + """check if pynccl is enabled for all reduce""" + global _ENABLE_PYNCCL_FOR_ALL_REDUCE + return _ENABLE_PYNCCL_FOR_ALL_REDUCE diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py new file mode 100644 index 0000000000000..9f0aaf5f9321b --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -0,0 +1,239 @@ +# ===================== pynccl.py ================================== +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/199366 +# ==================================================== + +# ===================== import region ===================== +import torch +import ctypes +import torch.distributed as dist +from torch.distributed import ReduceOp +import datetime +import os +import glob +import logging + +logger = logging.getLogger(__name__) + +so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") + +# manually load the nccl library +if so_file: + logger.info( + f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") +else: + _path = os.path.dirname(os.path.abspath(__file__)) + so_file = glob.glob(f"{_path}/../../lib/nvidia/nccl/lib/libnccl.so.*")[0] + logger.info(f"Loading nccl from vLLM builtin file {so_file}") +nccl = ctypes.CDLL(so_file) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int + +# equivalent to c declaration: +# ncclResult_t ncclGetVersion(int *version); +_c_ncclGetVersion = nccl.ncclGetVersion +_c_ncclGetVersion.restype = ctypes.c_int +_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] + + +def ncclGetVersion() -> int: + version = ctypes.c_int() + result = _c_ncclGetVersion(ctypes.byref(version)) + assert result == 0 + # something like 21903 --> "2.19.3" + version_str = str(version.value) + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + +class NcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +# equivalent to c declaration: +# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); +_c_ncclGetUniqueId = nccl.ncclGetUniqueId +_c_ncclGetUniqueId.restype = ctypes.c_int +_c_ncclGetUniqueId.argtypes = 
[ctypes.POINTER(NcclUniqueId)] + + +def ncclGetUniqueId() -> NcclUniqueId: + unique_id = NcclUniqueId() + result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) + assert result == 0 + return unique_id + + +# equivalent to c declaration: +# ncclResult_t ncclCommInitRank( +# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); +# note that ncclComm_t is a pointer type, so the first argument +# is a pointer to a pointer +_c_ncclCommInitRank = nccl.ncclCommInitRank +_c_ncclCommInitRank.restype = ctypes.c_int +_c_ncclCommInitRank.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int +] + + +# enums +class ncclDataType_t(ctypes.c_int): + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +class ncclRedOp_t(ctypes.c_int): + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +# equivalent to c declaration: +# ncclResult_t ncclAllReduce( +# const void* sendbuff, void* recvbuff, size_t count, +# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, +# udaStream_t stream); +# note that cudaStream_t is a pointer type, so the last argument is a pointer +_c_ncclAllReduce = nccl.ncclAllReduce +_c_ncclAllReduce.restype = ctypes.c_int +_c_ncclAllReduce.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p +] + +# equivalent to c declaration: +# ncclResult_t ncclCommDestroy(ncclComm_t comm); +_c_ncclCommDestroy = nccl.ncclCommDestroy +_c_ncclCommDestroy.restype = ctypes.c_int +_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] + + +class NCCLCommunicator: + + def __init__( + self, + backend=None, + init_method=None, + timeout=datetime.timedelta(seconds=10), + world_size: int = -1, + rank: int = -1, + store=None, + group_name: str = "", + pg_options=None, + ): + if not dist.is_initialized(): + backend = backend or "nccl" + assert backend == 'nccl', ( + "only use nccl backend for starting the NCCL communicator") + dist.init_process_group(backend=backend, + init_method=init_method, + timeout=timeout, + world_size=world_size, + rank=rank, + store=store, + group_name=group_name, + pg_options=pg_options) + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + torch.cuda.set_device(self.rank) + if self.rank == 0: + self.unique_id = ncclGetUniqueId() + else: + self.unique_id = NcclUniqueId() + tensor = 
torch.ByteTensor(list(self.unique_id.internal)).cuda( + self.rank) + dist.broadcast(tensor, src=0) + byte_list = tensor.cpu().tolist() + self.unique_id = NcclUniqueId() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + self.comm = ctypes.c_void_p() + result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank) + assert result == 0 + self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}") + + def all_reduce(self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if stream is None: + stream = self.stream + result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), + ctypes.c_void_p(tensor.data_ptr()), + tensor.numel(), + ncclDataType_t.from_torch(tensor.dtype), + ncclRedOp_t.from_torch(op), self.comm, + ctypes.c_void_p(stream.cuda_stream)) + assert result == 0 + + def __del__(self): + dist.destroy_process_group() + _c_ncclCommDestroy(self.comm) + + +# ===================== pynccl.py ===================== diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py new file mode 100644 index 0000000000000..e498526b71bb8 --- /dev/null +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -0,0 +1,67 @@ +import contextlib +import logging +import torch + +from typing import Optional +from torch.distributed import ReduceOp + +logger = logging.getLogger(__name__) + +try: + from vllm.model_executor.parallel_utils.pynccl import ( + NCCLCommunicator, + ncclGetVersion, + ) + logger.info(f"vLLM is using nccl=={ncclGetVersion()}") +except Exception as e: + # in non-NVIDIA environments, we can't import the nccl module + # e.g. when running on machines with AMD GPUs + logger.info(f"Failed to import NCCL library: {e}") + logger.info("It is expected if you are not running on NVIDIA GPUs.") + pass + +comm: Optional["NCCLCommunicator"] = None + + +def is_initialized() -> bool: + """Returns whether the NCCL backend is initialized.""" + return comm is not None + + +@contextlib.contextmanager +def set_pynccl_stream(stream: torch.cuda.Stream): + """Set the cuda stream for communication""" + try: + comm.stream = stream + yield + finally: + pass + + +def init_process_group(world_size: int, rank: int, host: str, + port: int) -> None: + assert not is_initialized() + global comm + comm = NCCLCommunicator(init_method=f"tcp://{host}:{port}", + world_size=world_size, + rank=rank) + + +def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: + """All-reduces the input tensor across the process group.""" + assert input_.is_cuda, f"{input_} should be a cuda tensor" + comm.all_reduce(input_, op) + + +def destroy_process_group() -> None: + global comm + comm = None + + +def get_world_size() -> int: + """Returns the world size.""" + return comm.world_size + + +def get_nccl_backend(): + return comm diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a08c3cbf5836..374f519afc81f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,34 +1,34 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple, Set import numpy as np import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) +from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, + SchedulerConfig) from 
vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata +from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( - with_cupy_nccl_for_all_reduce) + with_pynccl_for_all_reduce) +from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) logger = init_logger(__name__) +KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 @@ -50,7 +50,6 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - vision_language_config: Optional[VisionLanguageConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config @@ -85,20 +84,14 @@ def __init__( self.graph_block_tables = None # Set after initial profiling. 
self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model( - self.model_config, - self.device_config, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " @@ -134,9 +127,8 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -149,7 +141,6 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -195,10 +186,6 @@ def _prepare_prompt( (prompt_len - computed_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. 
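The NCCLCommunicator added in vllm/model_executor/parallel_utils/pynccl.py above bootstraps NCCL by hand: rank 0 asks NCCL for a 128-byte unique id, ships the raw bytes to the other ranks over the existing torch.distributed process group, and then every rank joins the communicator with ncclCommInitRank. The sketch below restates that handshake in isolation as a reading aid; it is not part of the patch, and it assumes the ctypes helpers from pynccl.py (ncclGetUniqueId, NcclUniqueId, _c_ncclCommInitRank) plus an already-initialized torch.distributed group with the CUDA device already set per rank.

# ===================== sketch (reading aid, not in the patch) =====================
# Condensed form of the unique-id handshake in NCCLCommunicator.__init__ above.
import ctypes

import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.pynccl import (NcclUniqueId,
                                                       ncclGetUniqueId,
                                                       _c_ncclCommInitRank)


def init_nccl_comm(rank: int, world_size: int) -> ctypes.c_void_p:
    # Assumes torch.distributed is initialized and torch.cuda.set_device(rank)
    # has been called, as NCCLCommunicator.__init__ does before this point.
    if rank == 0:
        unique_id = ncclGetUniqueId()  # rank 0 asks NCCL for the 128-byte id
    else:
        unique_id = NcclUniqueId()     # other ranks start from an empty struct
    # Ship the raw bytes over the existing torch.distributed group.
    tensor = torch.ByteTensor(list(unique_id.internal)).cuda(rank)
    dist.broadcast(tensor, src=0)
    for i, byte in enumerate(tensor.cpu().tolist()):
        unique_id.internal[i] = byte
    # Every rank now holds the same id and can join the NCCL communicator.
    comm = ctypes.c_void_p()
    result = _c_ncclCommInitRank(ctypes.byref(comm), world_size, unique_id,
                                 rank)
    assert result == 0
    return comm
# ===================== end of sketch =====================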
@@ -229,7 +216,7 @@ def _prepare_prompt( slot_mapping.append(slot) max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) + max_seq_len = max(prompt_lens) num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 @@ -247,16 +234,6 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = make_tensor_with_pad( @@ -293,7 +270,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - attn_metadata = self.attn_backend.make_metadata( + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens, @@ -302,7 +279,7 @@ def _prepare_prompt( num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, - max_prompt_len=max_prompt_len, + max_seq_len=max_seq_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, @@ -310,15 +287,15 @@ def _prepare_prompt( use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, + return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + lora_requests) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -424,7 +401,7 @@ def _prepare_decode( device=self.device, ) - attn_metadata = self.attn_backend.make_metadata( + input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, @@ -433,7 +410,7 @@ def _prepare_decode( num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, - max_prompt_len=None, + max_seq_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens, @@ -441,7 +418,7 @@ def _prepare_decode( use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, + return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) def _prepare_sample( @@ -545,25 +522,23 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, + Set[int], LoRAMapping]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, + (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) + lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, attn_metadata, + (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] subquery_lens = None - multi_modal_input = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -584,9 +559,8 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + metadata_dict.update(input_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) @@ -596,8 +570,7 @@ def prepare_input_tensors( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + input_metadata = InputMetadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -608,38 +581,34 @@ def prepare_input_tensors( perform_sampling=False, ) - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) + return (input_tokens, input_positions, input_metadata, + sampling_metadata, lora_requests, lora_mapping) @torch.inference_mode() def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[torch.Tensor], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, sampling_metadata, + lora_requests, + lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) # Execute the model. - if attn_metadata.use_cuda_graph: + if input_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - } - if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) - hidden_states = model_executable(**execute_model_kwargs) + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -687,22 +656,10 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. 
seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for vision encoding, which needs - # to be accounted for when calculating the GPU blocks for - # vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - if self.vision_language_config: - max_num_seqs = min( - max_num_seqs, - int(max_num_batched_tokens / - self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data, fake_multi_modal_input = _prepare_fake_inputs( - seq_len, self.vision_language_config) + seq_data = SequenceData([0] * seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -711,13 +668,12 @@ def profile_run(self) -> None: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=fake_multi_modal_input, ) seqs.append(seq) # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [None] * num_layers + kv_caches = [(None, None)] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() return @@ -749,7 +705,7 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + def capture_model(self, kv_caches: List[KVCache]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number @@ -764,7 +720,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. - self.cupy_nccl_backend = cupy_utils.get_nccl_backend() + self.cupy_nccl_backend = pynccl_utils.get_nccl_backend() assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " @@ -803,8 +759,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): - # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( + # Create dummy input_metadata. + input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, @@ -813,7 +769,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, + max_seq_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens[:batch_size], @@ -834,7 +790,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_tokens[:batch_size], input_positions[:batch_size], kv_caches, - attn_metadata, + input_metadata, memory_pool=self.graph_memory_pool, ) self.graph_memory_pool = graph_runner.graph.pool() @@ -870,22 +826,20 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, + kv_caches: List[KVCache], + input_metadata: InputMetadata, memory_pool, - **kwargs, ) -> None: assert self.graph is None # Run the model once without capturing the graph. 
# This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_cupy_nccl(): + with _maybe_pynccl(): self.model( input_ids, positions, kv_caches, - attn_metadata, - **kwargs, + input_metadata, ) torch.cuda.synchronize() @@ -894,13 +848,12 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 - with _maybe_cupy_nccl(): + with _maybe_pynccl(): hidden_states = self.model( input_ids, positions, kv_caches, - attn_metadata, - **kwargs, + input_metadata, ) torch.cuda.synchronize() @@ -909,9 +862,9 @@ def capture( "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "slot_mapping": input_metadata.slot_mapping, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -920,9 +873,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - **kwargs, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + input_metadata: InputMetadata, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. del kv_caches @@ -930,11 +882,11 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, + self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, + self.input_buffers["context_lens"].copy_(input_metadata.context_lens, non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, + self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) # Run the graph. 
self.graph.replay() @@ -947,9 +899,10 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager -def _maybe_cupy_nccl(): - if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized(): - with with_cupy_nccl_for_all_reduce(): +def _maybe_pynccl(): + if pynccl_utils.is_initialized( + ) and not custom_all_reduce.is_initialized(): + with with_pynccl_for_all_reduce(): yield else: yield @@ -968,21 +921,3 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _prepare_fake_inputs( - seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): - """Prepare fake inputs for profile run.""" - if vision_language_config: - prompt_tokens = [ - vision_language_config.image_token_id - ] * vision_language_config.image_feature_size + [0] * ( - seq_len - vision_language_config.image_feature_size) - fake_image_input = MultiModalData( - type=MultiModalData.Type.IMAGE, - data=torch.zeros(vision_language_config.image_input_shape, - dtype=torch.float16)) - else: - prompt_tokens = [0] * seq_len - fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 46a62fa693258..c979effae048d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,16 +1,15 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.lora.request import LoRARequest +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar @@ -19,6 +18,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.lora.request import LoRARequest class Worker: @@ -39,7 +39,6 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, - vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -55,20 +54,13 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.vision_language_config = vision_language_config - if self.vision_language_config: - assert not self.lora_config, ( - "To be tested: vision language model with LoRA settings.") - - self.model_runner = ModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). 
self.cache_config = None @@ -136,9 +128,6 @@ def profile_num_available_blocks( # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory - assert peak_memory > 0, ( - "Error in memory profiling. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes( block_size, cache_dtype) @@ -273,8 +262,8 @@ def init_distributed_environment( init_method=distributed_init_method, ) - if cupy_utils.is_initialized(): - cupy_world_size = cupy_utils.get_world_size() + if pynccl_utils.is_initialized(): + cupy_world_size = pynccl_utils.get_world_size() if cupy_world_size != parallel_config.world_size: raise RuntimeError( "cupy.distributed is already initialized but the cupy world " @@ -284,7 +273,7 @@ def init_distributed_environment( # NOTE(woosuk): We don't initialize CuPy process group when world size # is 1. # TODO(woosuk): Support multi-node connection. - cupy_utils.init_process_group( + pynccl_utils.init_process_group( world_size=parallel_config.world_size, rank=rank, host="localhost", @@ -293,8 +282,8 @@ def init_distributed_environment( # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) - if cupy_utils.is_initialized(): - cupy_utils.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) From d3439b996ee824b5e3dd9399c377a9df4dcf2868 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 17:55:18 -0700 Subject: [PATCH 03/29] isort --- tests/distributed/test_pynccl.py | 10 +++++----- .../parallel_utils/communication_op.py | 9 +++------ vllm/model_executor/parallel_utils/pynccl.py | 11 ++++++----- .../model_executor/parallel_utils/pynccl_utils.py | 10 ++++------ vllm/worker/model_runner.py | 15 +++++++-------- vllm/worker/worker.py | 8 ++++---- 6 files changed, 29 insertions(+), 34 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 58376306c277e..2adfcc0cb321d 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,13 +1,13 @@ # this script is not run with `pytest`. # It is run with `torchrun`. 
-import os import multiprocessing +import os + import pytest import torch -from vllm.model_executor.parallel_utils.pynccl import ( - NCCLCommunicator, - ncclGetUniqueId, -) + +from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, + ncclGetUniqueId) def distributed_run(fn, world_size): diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 28433d31f56a5..9cbb40708dd5b 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -5,14 +5,11 @@ from torch.distributed import ProcessGroup from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_group, - is_pynccl_enabled_for_all_reduce, -) from vllm.model_executor.parallel_utils.custom_all_reduce import ( custom_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 9f0aaf5f9321b..6175d56f25c6d 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -10,15 +10,16 @@ # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/199366 # ==================================================== -# ===================== import region ===================== -import torch import ctypes -import torch.distributed as dist -from torch.distributed import ReduceOp import datetime -import os import glob import logging +import os + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp logger = logging.getLogger(__name__) diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py index e498526b71bb8..01d45e8ce94a1 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -1,17 +1,15 @@ import contextlib import logging -import torch - from typing import Optional + +import torch from torch.distributed import ReduceOp logger = logging.getLogger(__name__) try: - from vllm.model_executor.parallel_utils.pynccl import ( - NCCLCommunicator, - ncclGetVersion, - ) + from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, + ncclGetVersion) logger.info(f"vLLM is using nccl=={ncclGetVersion()}") except Exception as e: # in non-NVIDIA environments, we can't import the nccl module diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 374f519afc81f..efe623aa5a034 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,28 +1,27 @@ import contextlib import time -from typing import Dict, List, Optional, Tuple, Set +from typing import Dict, List, Optional, Set, Tuple import numpy as np import torch import torch.nn as nn -from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger +from vllm.lora.layers import 
LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import pynccl_utils +from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( with_pynccl_for_all_reduce) -from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler, +from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c979effae048d..1457aca677132 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,13 +1,14 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Tuple, Set, Optional +from typing import Dict, List, Optional, Set, Tuple import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils import pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( @@ -18,7 +19,6 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner -from vllm.lora.request import LoRARequest class Worker: From 37fd9fd587ae3cd8549fa9deb691cf8a233a2ce1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 18:30:32 -0700 Subject: [PATCH 04/29] support amd rccl --- vllm/model_executor/parallel_utils/pynccl.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 6175d56f25c6d..be9d16a0c3a22 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -8,6 +8,14 @@ # contains many other potential cuda APIs, that are not allowed during # capturing the CUDA graph. For further details, please check # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/199366 +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually doable, +# but we often encounter issues related with nccl versions, and need to switch between +# different versions of NCCL. A C/C++ binding is not flexible enough to handle this. +# It requires recompilation of the code if we want to switch between different versions. +# This current implementation, with a pure Python wrapper, is more flexible. We can +# easily switch between different versions of NCCL by changing the environment variable, +# or the `so_file` variable in the code. 
# ==================================================== import ctypes @@ -30,9 +38,13 @@ logger.info( f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") else: - _path = os.path.dirname(os.path.abspath(__file__)) - so_file = glob.glob(f"{_path}/../../lib/nvidia/nccl/lib/libnccl.so.*")[0] - logger.info(f"Loading nccl from vLLM builtin file {so_file}") + if torch.cuda.version is not None: + so_file = "libnccl.so" + elif torch.hip.version is not None: + so_file = "librccl.so" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info(f"Loading nccl from library {so_file}") nccl = ctypes.CDLL(so_file) # === export types and functions from nccl to Python === @@ -48,7 +60,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] -def ncclGetVersion() -> int: +def ncclGetVersion() -> str: version = ctypes.c_int() result = _c_ncclGetVersion(ctypes.byref(version)) assert result == 0 From 9ec0b674c7f7296d2d6e3cc1e4529dd47dfe43af Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 18:32:24 -0700 Subject: [PATCH 05/29] leave todo --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 71acc31a39dc1..c7463b9728fd7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,8 +13,8 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") # used when building pytorch-related extensions -# TODO: only compute this if we are building wheels -# TODO: otherwise autodetect the version from current GPUs +# TODO(youkaichao): only compute this if we are building wheels +# TODO(youkaichao): otherwise autodetect the version from current GPUs set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.9;9.0") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") From 4904b7297ed20b855f1bf0bf17de0b6c4f2f84ec Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 18:34:50 -0700 Subject: [PATCH 06/29] add pynccl into test --- .buildkite/test-pipeline.yaml | 22 +++++++++++++--------- tests/distributed/test_pynccl.py | 2 -- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f6781de61af19..0654fcfef0da6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -22,8 +22,18 @@ steps: working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. -- label: Distributed Correctness Test - command: pytest -v -s --forked test_basic_distributed_correctness.py +- label: Distributed pynccl Test + command: pytest -v -s --forked test_pynccl.py + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 2 # only support 1 or 2 for now. + +- label: Distributed Correctness Test-facebook/opt-125m + command: TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 2 # only support 1 or 2 for now. + +- label: Distributed Correctness Test-meta-llama/Llama-2-7b-hf + command: TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. 
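A note on why the pynccl test above can run under plain pytest instead of torchrun: tests/distributed/test_pynccl.py spawns its own workers, giving each child the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables that torch.distributed's env:// rendezvous expects. The sketch below condenses the distributed_run/update_env helpers from that test; the port number is arbitrary, and the sketch is a reading aid rather than part of the patch.

# ===================== sketch (reading aid, not in the patch) =====================
# Condensed launcher in the spirit of distributed_run/update_env from
# tests/distributed/test_pynccl.py: spawn `world_size` processes and give each
# one the environment variables torch.distributed needs to rendezvous.
import multiprocessing
import os


def _with_env(fn, env) -> None:
    # Each child installs its rank-specific environment before running fn,
    # so NCCLCommunicator() inside fn can rendezvous via torch.distributed.
    os.environ.update(env)
    fn()


def distributed_run(fn, world_size: int) -> None:
    processes = []
    for rank in range(world_size):
        env = os.environ.copy()
        env['RANK'] = str(rank)
        env['WORLD_SIZE'] = str(world_size)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'  # arbitrary free port
        p = multiprocessing.Process(target=_with_env, args=(fn, env))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
# ===================== end of sketch =====================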
@@ -39,15 +49,9 @@ steps: - label: Models Test commands: - - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py --forked + - pytest -v -s models --forked soft_fail: true -- label: Llava Test - commands: - - bash ../.buildkite/download-images.sh - - pytest -v -s models/test_llava.py - - label: Prefix Caching Test commands: - pytest -v -s prefix_caching diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 2adfcc0cb321d..3b22e7a1eb27e 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,5 +1,3 @@ -# this script is not run with `pytest`. -# It is run with `torchrun`. import multiprocessing import os From 53e2ca319c52eecbcf72bbbfa42b58ceba46a73e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 18:38:36 -0700 Subject: [PATCH 07/29] linter --- vllm/model_executor/parallel_utils/pynccl.py | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index be9d16a0c3a22..4a3b4ff73feb3 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -7,20 +7,22 @@ # 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` # contains many other potential cuda APIs, that are not allowed during # capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/199366 -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually doable, -# but we often encounter issues related with nccl versions, and need to switch between -# different versions of NCCL. A C/C++ binding is not flexible enough to handle this. -# It requires recompilation of the code if we want to switch between different versions. -# This current implementation, with a pure Python wrapper, is more flexible. We can -# easily switch between different versions of NCCL by changing the environment variable, -# or the `so_file` variable in the code. +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
# ==================================================== import ctypes import datetime -import glob import logging import os From 5cdeb596885869547313d83bc56c070b2eb157f2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 19:41:32 -0700 Subject: [PATCH 08/29] update and merge from vllm/worker/model_runner.py in main --- vllm/worker/model_runner.py | 196 ++++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 65 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index efe623aa5a034..fdba79dd0bb43 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,28 +6,29 @@ import torch import torch.nn as nn +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import InputMetadata, SamplingMetadata +from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils +from vllm.model_executor.parallel_utils import pynccl_utils, custom_all_reduce from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( with_pynccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) logger = init_logger(__name__) -KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 @@ -49,6 +50,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config @@ -83,14 +85,20 @@ def __init__( self.graph_block_tables = None # Set after initial profiling. 
self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype + self.vision_language_config = vision_language_config + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + self.model_config, + self.device_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " @@ -126,8 +134,9 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], - List[int], List[int], Set[LoRARequest]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + List[int], List[int], List[int], Set[LoRARequest], + torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -140,6 +149,7 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -185,6 +195,10 @@ def _prepare_prompt( (prompt_len - computed_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. 
@@ -215,7 +229,7 @@ def _prepare_prompt( slot_mapping.append(slot) max_subquery_len = max(subquery_lens) - max_seq_len = max(prompt_lens) + max_prompt_len = max(prompt_lens) num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 @@ -233,6 +247,16 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = make_tensor_with_pad( @@ -269,7 +293,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - input_metadata = InputMetadata( + attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens, @@ -278,7 +302,7 @@ def _prepare_prompt( num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, - max_seq_len=max_seq_len, + max_prompt_len=max_prompt_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, @@ -286,15 +310,15 @@ def _prepare_prompt( use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, input_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) + lora_requests, multi_modal_input) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], - Set[LoRARequest]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -400,7 +424,7 @@ def _prepare_decode( device=self.device, ) - input_metadata = InputMetadata( + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, @@ -409,7 +433,7 @@ def _prepare_decode( num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, - max_seq_len=None, + max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens, @@ -417,7 +441,7 @@ def _prepare_decode( use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, input_metadata, + return (input_tokens, input_positions, attn_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) def _prepare_sample( @@ -521,23 +545,25 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, - Set[int], LoRAMapping]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, input_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_prompt(seq_group_metadata_list) + lora_requests, multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, + (input_tokens, input_positions, attn_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] subquery_lens = None + multi_modal_input = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -558,8 +584,9 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, } - metadata_dict.update(input_metadata.asdict_zerocopy()) + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) @@ -569,7 +596,8 @@ def prepare_input_tensors( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - input_metadata = InputMetadata(**metadata_dict) + multi_modal_input = metadata_dict.pop("multi_modal_input") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -580,34 +608,38 @@ def prepare_input_tensors( perform_sampling=False, ) - return (input_tokens, input_positions, input_metadata, - sampling_metadata, lora_requests, lora_mapping) + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, lora_requests, lora_mapping, + multi_modal_input) @torch.inference_mode() def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, input_metadata, sampling_metadata, - lora_requests, - lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, sampling_metadata, + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) # Execute the model. - if input_metadata.use_cuda_graph: + if attn_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - input_metadata=input_metadata, - ) + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -655,10 +687,22 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. 
seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + if self.vision_language_config: + max_num_seqs = min( + max_num_seqs, + int(max_num_batched_tokens / + self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data = SequenceData([0] * seq_len) + seq_data, fake_multi_modal_input = _prepare_fake_inputs( + seq_len, self.vision_language_config) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -667,12 +711,13 @@ def profile_run(self) -> None: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, + multi_modal_data=fake_multi_modal_input, ) seqs.append(seq) # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [(None, None)] * num_layers + kv_caches = [None] * num_layers self.execute_model(seqs, kv_caches) torch.cuda.synchronize() return @@ -704,7 +749,7 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() @torch.inference_mode() - def capture_model(self, kv_caches: List[KVCache]) -> None: + def capture_model(self, kv_caches: List[torch.Tensor]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number @@ -719,7 +764,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. - self.cupy_nccl_backend = pynccl_utils.get_nccl_backend() + self.pynccl_backend = pynccl_utils.get_nccl_backend() assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " @@ -758,8 +803,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): - # Create dummy input_metadata. - input_metadata = InputMetadata( + # Create dummy attn_metadata. + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, @@ -768,7 +813,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, - max_seq_len=None, + max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens[:batch_size], @@ -789,7 +834,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_tokens[:batch_size], input_positions[:batch_size], kv_caches, - input_metadata, + attn_metadata, memory_pool=self.graph_memory_pool, ) self.graph_memory_pool = graph_runner.graph.pool() @@ -806,7 +851,7 @@ def __del__(self) -> None: # happen. # FIXME(woosuk): This is a bit hacky. Find a more robust solution. 
self.graph_runners.clear() - self.cupy_nccl_backend = None + self.pynccl_backend = None @property def vocab_size(self) -> int: @@ -825,9 +870,10 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, memory_pool, + **kwargs, ) -> None: assert self.graph is None # Run the model once without capturing the graph. @@ -838,7 +884,8 @@ def capture( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -852,7 +899,8 @@ def capture( input_ids, positions, kv_caches, - input_metadata, + attn_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -861,9 +909,9 @@ def capture( "input_ids": input_ids, "positions": positions, "kv_caches": kv_caches, - "slot_mapping": input_metadata.slot_mapping, - "context_lens": input_metadata.context_lens, - "block_tables": input_metadata.block_tables, + "slot_mapping": attn_metadata.slot_mapping, + "context_lens": attn_metadata.context_lens, + "block_tables": attn_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -872,8 +920,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - input_metadata: InputMetadata, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. del kv_caches @@ -881,11 +930,11 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(input_metadata.context_lens, + self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, non_blocking=True) - self.input_buffers["block_tables"].copy_(input_metadata.block_tables, + self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, non_blocking=True) # Run the graph. 
self.graph.replay() @@ -899,8 +948,7 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): + if pynccl_utils.is_initialized() and not custom_all_reduce.is_initialized(): with with_pynccl_for_all_reduce(): yield else: @@ -920,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _prepare_fake_inputs( + seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): + """Prepare fake inputs for profile run.""" + if vision_language_config: + prompt_tokens = [ + vision_language_config.image_token_id + ] * vision_language_config.image_feature_size + [0] * ( + seq_len - vision_language_config.image_feature_size) + fake_image_input = MultiModalData( + type=MultiModalData.Type.IMAGE, + data=torch.zeros(vision_language_config.image_input_shape, + dtype=torch.float16)) + else: + prompt_tokens = [0] * seq_len + fake_image_input = None + return SequenceData(prompt_tokens), fake_image_input From af94dc647a4253adc7763ef3c3be775e2605ec17 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 19:43:01 -0700 Subject: [PATCH 09/29] fix isort --- vllm/worker/model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fdba79dd0bb43..b366852d9fb1d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -15,7 +15,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import pynccl_utils, custom_all_reduce +from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -948,7 +948,8 @@ def __call__(self, *args, **kwargs): @contextlib.contextmanager def _maybe_pynccl(): - if pynccl_utils.is_initialized() and not custom_all_reduce.is_initialized(): + if pynccl_utils.is_initialized( + ) and not custom_all_reduce.is_initialized(): with with_pynccl_for_all_reduce(): yield else: From af8e254282e4be261ab9d433fa1a4e964cdddc04 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 19:55:52 -0700 Subject: [PATCH 10/29] restore and merge vllm/worker/worker.py --- vllm/worker/worker.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1457aca677132..2d40d029d20a3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils import pynccl_utils @@ -39,6 +39,7 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -54,13 +55,20 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver 
worker must have rank 0." - self.model_runner = ModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + self.vision_language_config = vision_language_config + if self.vision_language_config: + assert not self.lora_config, ( + "To be tested: vision language model with LoRA settings.") + + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -128,6 +136,9 @@ def profile_num_available_blocks( # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes( block_size, cache_dtype) From db3044af4c0156c6e4be42745a5135098ca22aad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 20:44:25 -0700 Subject: [PATCH 11/29] fix hip condition --- vllm/model_executor/parallel_utils/pynccl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 4a3b4ff73feb3..90b5608449311 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -40,9 +40,9 @@ logger.info( f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") else: - if torch.cuda.version is not None: + if torch.cuda.is_available(): so_file = "libnccl.so" - elif torch.hip.version is not None: + elif torch.version.hip is not None: so_file = "librccl.so" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") From f078110edd55bd6a5414b9968adbf5cf27458af1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 20:53:46 -0700 Subject: [PATCH 12/29] fix cuda condition --- vllm/model_executor/parallel_utils/pynccl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 90b5608449311..cbad40c0298c9 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -40,7 +40,7 @@ logger.info( f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") else: - if torch.cuda.is_available(): + if torch.version.cuda is not None: so_file = "libnccl.so" elif torch.version.hip is not None: so_file = "librccl.so" From 18c9437e048934b0d8038986d9a5448d3582e0e1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 21:01:28 -0700 Subject: [PATCH 13/29] add error message when libnccl cannot be found --- vllm/model_executor/parallel_utils/pynccl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index cbad40c0298c9..4ea70bc1bb9c9 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -47,7 +47,17 @@ else: raise ValueError("NCCL only supports 
CUDA and ROCm backends.") logger.info(f"Loading nccl from library {so_file}") -nccl = ctypes.CDLL(so_file) + +try: + nccl = ctypes.CDLL(so_file) +except Exception as e: + logger.error( + f"Failed to load NCCL library from {so_file} ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path." + ) + raise e # === export types and functions from nccl to Python === # for the original nccl definition, please check From 3f1db02d9474bef7ed0302ae2f9770c5d5c26aa6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 21:05:52 -0700 Subject: [PATCH 14/29] fix test --- tests/distributed/test_pynccl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 3b22e7a1eb27e..df291f87f80bf 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -59,12 +59,14 @@ def worker_fn_with_cudagraph(): a = torch.ones((4, 4), device=f'cuda:{comm.rank}') torch.cuda.synchronize() with torch.cuda.graph(graph, stream=comm.stream): + # TODO(youkaichao) + # seems like this operation is not performed comm.all_reduce(a) comm.stream.synchronize() assert a.mean().cpu().item() == comm.world_size**0 graph.replay() comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**2 + assert a.mean().cpu().item() == comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, From 6c7082d50ab9fbc4125c1fe9cf6d2c1c445a5c74 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 21:17:13 -0700 Subject: [PATCH 15/29] fix lint --- .../test_basic_distributed_correctness.py | 17 +++++++++++++---- vllm/model_executor/parallel_utils/pynccl.py | 3 +-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 82075356fccbd..1eba14d7a6422 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -1,13 +1,22 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. - -Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_basic_distributed_correctness.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_basic_distributed_correctness.py +``` """ +import os + import pytest import torch MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + os.environ["TEST_DIST_MODEL"], ] diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 4ea70bc1bb9c9..556f50d2604bf 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -55,8 +55,7 @@ f"Failed to load NCCL library from {so_file} ." "It is expected if you are not running on NVIDIA/AMD GPUs." "Otherwise please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path." 
- ) + " to point to the correct nccl library path.") raise e # === export types and functions from nccl to Python === From 80eec0b3f67eec5b7ab9b3476fd1a6ed1a7430cc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 21:51:06 -0700 Subject: [PATCH 16/29] restore test of model and llava --- .buildkite/test-pipeline.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0654fcfef0da6..ac53de50950b8 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -49,9 +49,15 @@ steps: - label: Models Test commands: - - pytest -v -s models --forked + - bash ../.buildkite/download-images.sh + - pytest -v -s models --ignore=models/test_llava.py --forked soft_fail: true +- label: Llava Test + commands: + - bash ../.buildkite/download-images.sh + - pytest -v -s models/test_llava.py + - label: Prefix Caching Test commands: - pytest -v -s prefix_caching From efb52bf7ab77c396d9640704f4b9b72cebd23652 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 22:00:54 -0700 Subject: [PATCH 17/29] try to clean up cupy --- requirements.txt | 1 - setup.py | 6 ------ vllm/executor/ray_gpu_executor.py | 6 ++---- .../parallel_utils/parallel_state.py | 2 +- vllm/test_utils.py | 1 - vllm/worker/model_runner.py | 8 +++---- vllm/worker/worker.py | 21 ++++++++----------- 7 files changed, 16 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index eb9977d93dd8d..6d75067b34a7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,3 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 outlines == 0.0.34 -cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. diff --git a/setup.py b/setup.py index 9c9a428f94683..941d315d573bb 100644 --- a/setup.py +++ b/setup.py @@ -306,12 +306,6 @@ def get_requirements() -> List[str]: if _is_cuda(): with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") - if get_nvcc_cuda_version() <= Version("11.8"): - # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x - for i in range(len(requirements)): - if requirements[i].startswith("cupy-cuda12x"): - requirements[i] = "cupy-cuda11x" - break elif _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index f2fc8aec9887b..d8288832a5300 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,11 +188,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - # FIXME(woosuk): We are not properly initializing cupy NCCL when + # FIXME(woosuk): We are not properly initializing pynccl when # we have multiple nodes. - self._run_workers("init_device", - cupy_port=get_open_port() - if not model_config.enforce_eager else None) + self._run_workers("init_device") self._run_workers( "load_model", max_concurrent_workers=self.parallel_config. diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 63890d9cd5bd8..bcda5ebf8548b 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -226,7 +226,7 @@ def with_pynccl_for_all_reduce(): tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: # No-op. 
- # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. + # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. yield else: global _ENABLE_PYNCCL_FOR_ALL_REDUCE diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 75bf6ce373d93..5bd3ebbbd5b1a 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -18,7 +18,6 @@ def init_test_distributed_environment( init_distributed_environment( parallel_config, rank, - cupy_port=None, distributed_init_method=distributed_init_method) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b366852d9fb1d..e77d145a0794b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -794,11 +794,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ] # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or CuPy NCCL. When not using CUDA + # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA # graph, we use either custom all-reduce kernel or PyTorch NCCL. # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or CuPy NCCL if it is disabled or not supported. + # to PyTorch or pynccl if it is disabled or not supported. with custom_all_reduce.capture(): # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. @@ -846,7 +846,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") def __del__(self) -> None: - # Delete the CUDA graphs before deleting the CuPy NCCL communicator. + # Delete the CUDA graphs before deleting the pynccl communicator. # NOTE(woosuk): This is necessary because otherwise deadlocks can # happen. # FIXME(woosuk): This is a bit hacky. Find a more robust solution. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2d40d029d20a3..1258dd8ccbe57 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -75,7 +75,7 @@ def __init__( self.cache_engine = None self.gpu_cache = None - def init_device(self, cupy_port: Optional[int] = None) -> None: + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow @@ -97,8 +97,7 @@ def init_device(self, cupy_port: Optional[int] = None) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, - cupy_port, self.distributed_init_method) + init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method) # Set random seed. 
set_random_seed(self.model_config.seed) @@ -250,7 +249,6 @@ def get_cache_block_size_bytes(self, block_size: int, def init_distributed_environment( parallel_config: ParallelConfig, rank: int, - cupy_port: Optional[int], distributed_init_method: Optional[str] = None, ) -> None: """Initialize the distributed environment.""" @@ -274,21 +272,20 @@ def init_distributed_environment( ) if pynccl_utils.is_initialized(): - cupy_world_size = pynccl_utils.get_world_size() - if cupy_world_size != parallel_config.world_size: + pynccl_world_size = pynccl_utils.get_world_size() + if pynccl_world_size != parallel_config.world_size: raise RuntimeError( - "cupy.distributed is already initialized but the cupy world " + "pynccl is already initialized but the pynccl world " "size does not match parallel_config.world_size " - f"({cupy_world_size} vs. {parallel_config.world_size}).") - elif (parallel_config.world_size > 1 and cupy_port is not None): - # NOTE(woosuk): We don't initialize CuPy process group when world size + f"({pynccl_world_size} vs. {parallel_config.world_size}).") + elif parallel_config.world_size > 1: + # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. # TODO(woosuk): Support multi-node connection. pynccl_utils.init_process_group( world_size=parallel_config.world_size, rank=rank, - host="localhost", - port=cupy_port, + init_method=distributed_init_method, ) # A small all_reduce for warmup. From e6fb64d097a695f9125a0ea8f46fbb7f9e1c583b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 22:01:06 -0700 Subject: [PATCH 18/29] lint --- vllm/test_utils.py | 4 +--- vllm/worker/worker.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 5bd3ebbbd5b1a..5b2eeafad197e 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -16,9 +16,7 @@ def init_test_distributed_environment( worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - parallel_config, - rank, - distributed_init_method=distributed_init_method) + parallel_config, rank, distributed_init_method=distributed_init_method) def multi_process_tensor_parallel( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1258dd8ccbe57..6459c0cda669a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -97,7 +97,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method) + init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) # Set random seed. 
set_random_seed(self.model_config.seed) From 5eb972e0a2cccb6e48789df48d005c23ca5267b9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 22:44:33 -0700 Subject: [PATCH 19/29] unify init_method --- vllm/model_executor/parallel_utils/pynccl_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py index 01d45e8ce94a1..a12d620d7a24c 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/model_executor/parallel_utils/pynccl_utils.py @@ -36,11 +36,10 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(world_size: int, rank: int, host: str, - port: int) -> None: +def init_process_group(world_size: int, rank: int, init_method: str) -> None: assert not is_initialized() global comm - comm = NCCLCommunicator(init_method=f"tcp://{host}:{port}", + comm = NCCLCommunicator(init_method=init_method, world_size=world_size, rank=rank) From 915213c629669ba8ece5eb84412b9619651abaa6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 23:04:21 -0700 Subject: [PATCH 20/29] further cleanup cupy --- Dockerfile | 2 +- Dockerfile.rocm | 20 -------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/Dockerfile b/Dockerfile index 1f254c76fe5af..20bbff34b7fcb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal #################### RUNTIME BASE IMAGE #################### # We used base cuda image because pytorch installs its own cuda libraries. -# However cupy depends on cuda libraries so we had to switch to the runtime image +# However pynccl depends on cuda libraries so we had to switch to the runtime image # In the future it would be nice to get a container with pytorch and cuda without duplicating cuda FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a45265d79a6ac..a09de99f7a468 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH" # In that case, we need to use the python reference attention implementation in vllm ARG BUILD_FA="1" -# whether to build cupy on rocm -ARG BUILD_CUPY="1" - # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi -# build cupy -RUN if [ "$BUILD_CUPY" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \ - && cd cupy \ - && pip install mpi4py-mpich \ - && pip install scipy==1.9.3 \ - && pip install cython==0.29.* \ - && env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \ - && export CUPY_INSTALL_USE_HIP=1 \ - && export ROCM_HOME=/opt/rocm \ - && export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \ - && pip install . 
\ - && cd ..; \ - fi - COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip From 9b4d6fc0d48d0970fddb01faadd48aa0bea0e1a6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Mar 2024 23:58:33 -0700 Subject: [PATCH 21/29] do not know why, but this fixes ray test --- tests/distributed/test_comm_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 1d376b18a66b3..1a8e960dc31b6 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -16,6 +16,10 @@ @ray.remote(num_gpus=1, max_calls=1) def all_reduce_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + import os + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) num_elements = 8 @@ -32,6 +36,10 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + import os + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) num_dimensions = 3 @@ -54,6 +62,10 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): + import os + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) test_dict = { From 983243ea9941b4178c53ebc4631c13f21de5a624 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 11:41:29 -0700 Subject: [PATCH 22/29] merge distributed tests --- .buildkite/test-pipeline.yaml | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ac53de50950b8..d7d2b930e6e7a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -22,20 +22,13 @@ steps: working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. -- label: Distributed pynccl Test - command: pytest -v -s --forked test_pynccl.py - working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. - -- label: Distributed Correctness Test-facebook/opt-125m - command: TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py - working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. - -- label: Distributed Correctness Test-meta-llama/Llama-2-7b-hf - command: TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py +- label: Distributed Tests working_dir: "/vllm-workspace/tests/distributed" num_gpus: 2 # only support 1 or 2 for now. 
+ commands: + - pytest -v -s --forked test_pynccl.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py From d2c9b4b4d00277ebf1f962de49875f228fdb885d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 12:29:22 -0700 Subject: [PATCH 23/29] update comment for allreduce graph capture --- tests/distributed/test_pynccl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index df291f87f80bf..d6901e8acec91 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -59,8 +59,8 @@ def worker_fn_with_cudagraph(): a = torch.ones((4, 4), device=f'cuda:{comm.rank}') torch.cuda.synchronize() with torch.cuda.graph(graph, stream=comm.stream): - # TODO(youkaichao) - # seems like this operation is not performed + # operation during the graph capture is recorded but not executed + # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa comm.all_reduce(a) comm.stream.synchronize() assert a.mean().cpu().item() == comm.world_size**0 From cefae388f90a44a18a8c973e285562dbab1b6374 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 12:31:41 -0700 Subject: [PATCH 24/29] explain why delete CUDA_VISIBLE_DEVICES --- tests/distributed/test_comm_ops.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 1a8e960dc31b6..0395f7200fd77 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -2,6 +2,8 @@ Run `pytest tests/distributed/test_comm_ops.py --forked`. 
""" +import os + import pytest import ray import torch @@ -16,7 +18,9 @@ @ray.remote(num_gpus=1, max_calls=1) def all_reduce_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): - import os + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) @@ -36,7 +40,9 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): - import os + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) @@ -62,7 +68,9 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, distributed_init_port: str): - import os + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) From da58b34b4709da5ebf1fd446714a4b0622f33973 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 12:34:49 -0700 Subject: [PATCH 25/29] explain update_env --- tests/distributed/test_pynccl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index d6901e8acec91..797f18915dec9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -26,7 +26,9 @@ def distributed_run(fn, world_size): def update_env(fn): - + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function def wrapper(env): import os os.environ.update(env) From 850eca1de58229cd26da620598d83087b4c15e27 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 12:37:48 -0700 Subject: [PATCH 26/29] move successful load message to debug level --- vllm/model_executor/parallel_utils/pynccl.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 556f50d2604bf..0bea4b6997d2b 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -1,4 +1,3 @@ -# ===================== pynccl.py ================================== # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. # Before writing this script, we tried the following approach: @@ -19,7 +18,6 @@ # more flexible. We can easily switch between different versions of NCCL by # changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` # variable in the code. 
-# ==================================================== import ctypes import datetime @@ -46,7 +44,7 @@ so_file = "librccl.so" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info(f"Loading nccl from library {so_file}") + logger.debug(f"Loading nccl from library {so_file}") try: nccl = ctypes.CDLL(so_file) @@ -58,6 +56,7 @@ " to point to the correct nccl library path.") raise e + # === export types and functions from nccl to Python === # for the original nccl definition, please check # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in @@ -258,6 +257,3 @@ def all_reduce(self, def __del__(self): dist.destroy_process_group() _c_ncclCommDestroy(self.comm) - - -# ===================== pynccl.py ===================== From 0ed6527e54cc4ca6c28234630a0929641ebe2ee6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 12:39:09 -0700 Subject: [PATCH 27/29] fix lint --- vllm/model_executor/parallel_utils/pynccl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py index 0bea4b6997d2b..e8f0f74ad186b 100644 --- a/vllm/model_executor/parallel_utils/pynccl.py +++ b/vllm/model_executor/parallel_utils/pynccl.py @@ -56,7 +56,6 @@ " to point to the correct nccl library path.") raise e - # === export types and functions from nccl to Python === # for the original nccl definition, please check # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in From 0361127364442adffaa60373beb255f168505ac2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 23:30:55 -0700 Subject: [PATCH 28/29] restore CMakeLists.txt --- CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7463b9728fd7..66842e6845edd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,10 +12,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") -# used when building pytorch-related extensions -# TODO(youkaichao): only compute this if we are building wheels -# TODO(youkaichao): otherwise autodetect the version from current GPUs -set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.9;9.0") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") From 6d18496c0289973d33e919c3124b0f49a3cde67d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 26 Mar 2024 23:33:51 -0700 Subject: [PATCH 29/29] leave a todo for __del__ --- vllm/worker/model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e77d145a0794b..f0c98700ab749 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -850,6 +850,8 @@ def __del__(self) -> None: # NOTE(woosuk): This is necessary because otherwise deadlocks can # happen. # FIXME(woosuk): This is a bit hacky. Find a more robust solution. + # TODO(youkaichao): when we get enough user feedback that pynccl is + # more stable than cupy, we can remove this, e.g. in v0.4.1. self.graph_runners.clear() self.pynccl_backend = None
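
The series settles on a small ctypes-based idiom for locating and loading the NCCL shared library (patches 11-13 and 26 above): honor `VLLM_NCCL_SO_PATH` if set, otherwise pick `libnccl.so` or `librccl.so` from the torch build. The following is a minimal standalone sketch of that idiom, not the vLLM module itself; the helper names `load_nccl` and `nccl_version` are invented for this sketch, while the environment variable, the library fallbacks, and the `ncclGetVersion` C entry point come from the patches.

```python
# Standalone sketch (not vLLM code) of the ctypes-based NCCL loading idiom
# used by vllm/model_executor/parallel_utils/pynccl.py.
import ctypes
import os

import torch


def load_nccl() -> ctypes.CDLL:  # helper name is illustrative only
    # Prefer an explicit override, then fall back to the system library.
    so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
    if not so_file:
        if torch.version.cuda is not None:
            so_file = "libnccl.so"
        elif torch.version.hip is not None:
            so_file = "librccl.so"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
    return ctypes.CDLL(so_file)


def nccl_version(nccl: ctypes.CDLL) -> int:  # helper name is illustrative only
    # C signature: ncclResult_t ncclGetVersion(int* version); 0 means success.
    fn = nccl.ncclGetVersion
    fn.restype = ctypes.c_int
    fn.argtypes = [ctypes.POINTER(ctypes.c_int)]
    version = ctypes.c_int()
    assert fn(ctypes.byref(version)) == 0
    return version.value  # e.g. 21903 for NCCL 2.19.3


if __name__ == "__main__":
    print(nccl_version(load_nccl()))
```

Keeping the binding in pure Python is what lets the loaded library be swapped by changing `VLLM_NCCL_SO_PATH` (or the `so_file` variable), with no extension recompile.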
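
Patch 14's TODO ("seems like this operation is not performed") and the comment patch 23 puts in its place both hinge on one CUDA graph rule: work issued while a graph is being captured is recorded, not executed, and only runs on replay. A single-GPU sketch of that behavior, assuming a CUDA-capable device and no NCCL involvement:

```python
# Single-GPU illustration of capture-vs-replay semantics; this is not the
# vLLM test, just the underlying torch.cuda.graph behavior it relies on.
import torch

assert torch.cuda.is_available()
a = torch.zeros(4, device="cuda")
a += 1  # warm up the kernel outside of capture
a.zero_()
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    a += 1  # recorded into the graph, not executed yet

torch.cuda.synchronize()
assert a.sum().item() == 0  # capture alone did not change the tensor

g.replay()  # now the recorded kernel actually runs
torch.cuda.synchronize()
assert a.sum().item() == 4
```

This is why the test asserts the unchanged value after capture and the reduced value only after `graph.replay()`.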