From d86a64ad6ffb1c4e82d5e681e3175de73a0af668 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 19 Mar 2024 12:24:46 +0000 Subject: [PATCH] Add CPU CMake extension. --- CMakeLists.txt | 16 ++++++++ cmake/cpu_extension.cmake | 77 +++++++++++++++++++++++++++++++++++++++ setup.py | 19 ++++++++-- 3 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 cmake/cpu_extension.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 29a531d44a9d5..70b9d61b5a973 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21) project(vllm_extensions LANGUAGES CXX) +set(VLLM_TARGET_DEVICE "cpu" CACHE STRING "Target device backend for vLLM") + message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") +message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) @@ -69,6 +72,19 @@ find_package(Torch REQUIRED) # append_torchlib_if_found(torch_python) +# +# Forward the non-CUDA device extensions to external CMake scripts. +# +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + NOT VLLM_TARGET_DEVICE STREQUAL "rocm") + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() + message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + endif() + return() +endif() + # # Set up GPU language and check the torch version and warn if it isn't # what is expected. 
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake new file mode 100644 index 0000000000000..236573919b21c --- /dev/null +++ b/cmake/cpu_extension.cmake @@ -0,0 +1,75 @@ +# +# Check the compile flags +# +list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") + +execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + +if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +endif() + +function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + +if (AVX512_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mavx512f" + "-mavx512vl" + "-mavx512bw" + "-mavx512dq") + + find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) + if (AVX512BF16_FOUND AND + CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + else() + message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") + endif() +else() + message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + +# +# Define extension targets +# + +# +# _C extension +# +set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/attention.cpp" + "csrc/cpu/cache.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/pybind.cpp") + +define_gpu_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + WITH_SOABI +) + +add_custom_target(default) +message(STATUS "Enabling C extension.") +add_dependencies(default _C) + diff --git a/setup.py index 88787334be21a..eeab2f5ec570b 100644 --- a/setup.py 
+++ b/setup.py @@ -13,6 +13,8 @@ from torch.utils.cpp_extension import CUDA_HOME ROOT_DIR = os.path.dirname(__file__) +# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] +VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") # vLLM only supports Linux platform assert sys.platform.startswith( @@ -61,8 +63,7 @@ def compute_num_jobs(self): except AttributeError: num_jobs = os.cpu_count() - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version >= Version("11.2"): + if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"): nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) num_jobs = max(1, round(num_jobs / (nvcc_threads / 4))) else: @@ -95,6 +96,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir), '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp), + '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] verbose = bool(int(os.getenv('VERBOSE', '0'))) @@ -168,11 +170,12 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return torch.version.cuda is not None + return VLLM_TARGET_DEVICE == "cuda" and torch.version.cuda is not None def _is_hip() -> bool: - return torch.version.hip is not None + return (VLLM_TARGET_DEVICE == "cuda" + or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None def _is_neuron() -> bool: @@ -184,6 +187,10 @@ def _is_neuron() -> bool: return torch_neuronx_installed +def _is_cpu() -> bool: + return VLLM_TARGET_DEVICE == "cpu" + + def _install_punica() -> bool: return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) @@ -279,6 +286,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_cpu(): + version += "+cpu" else: raise RuntimeError("Unknown runtime environment") @@ -311,6 +320,8 @@ def get_requirements() -> List[str]: elif _is_neuron(): with 
open(get_path("requirements-neuron.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_cpu(): + requirements = [] else: raise ValueError( "Unsupported platform, please use CUDA, ROCM or Neuron.")