diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
new file mode 100644
index 0000000000000..f187d1f181724
--- /dev/null
+++ b/.buildkite/run-cpu-test.sh
@@ -0,0 +1,14 @@
+# This script builds the CPU docker image and runs the offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# Try building the docker image
+docker build -t cpu-test -f Dockerfile.cpu .
+
+# Setup cleanup
+remove_docker_container() { docker rm -f cpu-test || true; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Run the image and launch offline inference
+docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index 4dde733581822..3ed23c62c005d 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -8,6 +8,9 @@ steps:
       queue: amd
     command: bash .buildkite/run-amd-test.sh
 
+  - label: "CPU Test"
+    command: bash .buildkite/run-cpu-test.sh
+
   - label: ":docker: build image"
     commands:
       - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh
index 2578d448436d2..ed200fe724d3e 100644
--- a/.github/workflows/scripts/build.sh
+++ b/.github/workflows/scripts/build.sh
@@ -15,6 +15,7 @@ $python_executable -m pip install -r requirements.txt
 export MAX_JOBS=1
 # Make sure punica is built for the release (for LoRA)
 export VLLM_INSTALL_PUNICA_KERNELS=1
-
+# Make sure release wheels are built for the following architectures
+export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 412b9c0cd59e0..9d90f4e7a0496 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)
 
 project(vllm_extensions LANGUAGES CXX)
 
+option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
+
 message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
 
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 
@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
 find_library(torch_python_LIBRARY torch_python PATHS
   "${TORCH_INSTALL_PREFIX}/lib")
 
+#
+# Forward the non-CUDA device extensions to external CMake scripts.
+#
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
+    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
+    if (VLLM_TARGET_DEVICE STREQUAL "cpu")
+        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
+    else()
+        message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
+    endif()
+    return()
+endif()
+
 #
 # Set up GPU language and check the torch version and warn if it isn't
 # what is expected.
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
new file mode 100644
index 0000000000000..4251fddd6cc3b
--- /dev/null
+++ b/Dockerfile.cpu
@@ -0,0 +1,20 @@
+# This vLLM Dockerfile is used to construct an image that can build and run vLLM on the x86 CPU platform.
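+#
+# Example usage (values are illustrative; see docs/source/getting_started/cpu-installation.rst
+# and .buildkite/run-cpu-test.sh in this change for the canonical commands):
+#   docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g .
+#   docker run -it --rm --network=host \
+#       --env VLLM_CPU_KVCACHE_SPACE=<KV cache size in GiB> \
+#       vllm-cpu-env python3 examples/offline_inference.py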
+ +FROM ubuntu:22.04 + +RUN apt-get update -y \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +RUN pip install --upgrade pip \ + && pip install wheel packaging ninja setuptools>=49.4.0 numpy + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install + +CMD ["/bin/bash"] diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 96a372e5511b7..ad428bd1c3644 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -334,7 +334,8 @@ async def async_request_openai_chat_completions( timestamp = time.perf_counter() data = json.loads(chunk) - if "content" in data["choices"][0]["delta"]: + delta = data["choices"][0]["delta"] + if delta.get("content", None): # First token if ttft == 0: ttft = time.perf_counter() - st @@ -345,8 +346,7 @@ async def async_request_openai_chat_completions( output.itl.append(timestamp - most_recent_timestamp) - generated_text += data["choices"][0]["delta"][ - "content"] + generated_text += delta["content"] most_recent_timestamp = timestamp diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake new file mode 100644 index 0000000000000..0cf37769a6960 --- /dev/null +++ b/cmake/cpu_extension.cmake @@ -0,0 +1,90 @@ +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# +# Define environment variables for special configurations +# +if(DEFINED ENV{VLLM_CPU_AVX512BF16}) + set(ENABLE_AVX512BF16 ON) +endif() + +include_directories("${CMAKE_SOURCE_DIR}/csrc") + +# +# Check the compile flags +# +list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") + +execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + +if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") +endif() + +function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +find_isa(${CPUINFO} "avx512f" AVX512_FOUND) + +if (AVX512_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mavx512f" + "-mavx512vl" + "-mavx512bw" + "-mavx512dq") + + find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) + if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + else() + message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." 
" If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") + endif() +else() + message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + + +# +# Define extension targets +# + +# +# _C extension +# +set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/attention.cpp" + "csrc/cpu/cache.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/cpu/pybind.cpp") + +define_gpu_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + WITH_SOABI +) + +add_custom_target(default) +message(STATUS "Enabling C extension.") +add_dependencies(default _C) + diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp new file mode 100644 index 0000000000000..1bd24eb79d129 --- /dev/null +++ b/csrc/cpu/activation.cpp @@ -0,0 +1,148 @@ +#include "cpu_types.hpp" + +namespace { +template +void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, + scalar_t *__restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + int start = i * d; + if constexpr (is_gated) { + start *= 2; + } + + const scalar_vec_t x(input + start + j); + const vec_op::FP32Vec8 f32_x(x); + vec_op::FP32Vec8 f32_ans = func(f32_x); + + if constexpr (is_gated) { + const scalar_vec_t y(input + start + d + j); + const vec_op::FP32Vec8 f32_y(y); + f32_ans = f32_y * f32_ans; + } + + const scalar_vec_t result(f32_ans); + result.save(output + i * d + j); + } + } +} + +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 x3 = x * x * x; + const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); + const vec_op::FP32Vec8 w3(0.044715); + const vec_op::FP32Vec8 x_3 = x * x * x; + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); +} +}; // namespace + +void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + 
activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); +} + +void gelu_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel(num_tokens, d, + input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); +} + +void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_tanh_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl) + }); +} + +void gelu_new(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_new_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_new_impl) + }); +} + +void gelu_fast(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_fast_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); +} diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp new file mode 100644 index 0000000000000..6f38e923d7d6f --- /dev/null +++ b/csrc/cpu/attention.cpp @@ -0,0 +1,744 @@ +#include "cpu_types.hpp" + +namespace { + +template struct KernelVecType { + using q_load_vec_type = void; + using q_vec_type = void; + using k_load_vec_type = void; + using k_vec_type = void; + using qk_acc_vec_type = void; + using v_load_vec_type = void; +}; + +template <> struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec4; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP32Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512BF16__ +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::BF16Vec32; + using k_vec_type = vec_op::BF16Vec32; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#else +template <> struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; +#endif + +template +FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, + const int capacity) { + T max = data[0]; + for (int i = 1; i < size; ++i) { + max = max >= data[i] ? 
max : data[i]; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE std::pair +reduceSoftmaxAlibi(T *data, const int size, const int capacity, + const float alibi_slope, const int start_index, + const int context_len) { + data[0] += alibi_slope * (start_index - context_len + 1); + T max = data[0]; + for (int i = 1; i < size; ++i) { + T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + data[i] = qk; + max = max >= qk ? max : qk; + } + + T sum = 0; + for (int i = 0; i < size; ++i) { + data[i] = std::exp(data[i] - max); + sum += data[i]; + } + + int i = 0; + for (; i < size; ++i) { + data[i] /= sum; + } + + for (; i < capacity; ++i) { + data[i] = 0; + } + + return {max, sum}; +} + +template +FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, + const int size) { + T max = max_data[0]; + for (int i = 1; i < size; ++i) { + max = max >= max_data[i] ? max : max_data[i]; + } + + T rescaled_sum = 0; + for (int i = 0; i < size; ++i) { + T rescale_factor = std::exp(max_data[i] - max); + rescaled_sum += rescale_factor * sum_data[i]; + sum_data[i] *= rescale_factor; + } + for (int i = 0; i < size; ++i) { + sum_data[i] /= rescaled_sum + 1e-8; + } +} + +template +struct reduceQKBlockKernel { + using q_load_vec_type = typename KernelVecType::q_load_vec_type; + using q_vec_type = typename KernelVecType::q_vec_type; + using k_load_vec_type = typename KernelVecType::k_load_vec_type; + using k_vec_type = typename KernelVecType::k_vec_type; + using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; + + constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; + constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; + constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; + + static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); + static_assert(k_load_vec_type::get_elem_num() % x == 0); + static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); + + FORCE_INLINE static void call(const scalar_t *__restrict__ q, + const scalar_t *__restrict__ k_block, + float *__restrict__ logits, float scale, + const int token_num) { + const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; + + qk_acc_vec_type group_accums[MAX_GROUP_NUM]; + if (token_num == BLOCK_SIZE) { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + + vec_op::unroll_loop( + [k_block, &q_group_vec, &group_accums](int token_group_idx) { + k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } else { + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block += x * BLOCK_SIZE) { + q_load_vec_type q_load_group_vec(q + q_offset); + q_vec_type q_group_vec(q_load_group_vec); + for (int token_group_start = 0; token_group_start < group_num; + token_group_start += UNROLL_GROUP_NUM) { + vec_op::unroll_loop( + [token_group_start, k_block, &q_group_vec, + &group_accums](int token_group_idx) { + token_group_idx += token_group_start; + 
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * + TOKEN_PER_GROUP); + k_vec_type k_group_vec(k_load_group_vec); + vec_op::fma(group_accums[token_group_idx], q_group_vec, + k_group_vec); + vec_op::prefetch(k_block + x * BLOCK_SIZE + + token_group_idx * x * TOKEN_PER_GROUP); + }); + } + } + } + + for (int token_group_idx = 0; token_group_idx < group_num; + ++token_group_idx) { + vec_op::unroll_loop( + [&group_accums, logits, scale, token_group_idx](int token_idx) { + float dot_v = + group_accums[token_group_idx] + .template reduce_sub_sum(token_idx); + logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = + dot_v * scale; + }); + } + } +}; + +template +FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, + acc_t &&acc) { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); + static_assert(BLOCK_SIZE == ELEM_NUM); + vec_op::FP32Vec16 prob_vec(prob); + + vec_op::unroll_loop([&](int head_elem_idx) { + v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); + vec_op::FP32Vec16 fp32_v_vec(v_vec); + acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; + }); +} +}; // namespace + +// Paged attention v1 +namespace { +template +struct paged_attention_v1_impl { + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + + int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + + const int parallel_work_item_num = omp_get_max_threads(); + + size_t logits_bytes = + parallel_work_item_num * max_context_len_padded * sizeof(float); + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. 
+ // [parallel_work_item_num, max_context_len_padded] + +#pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int context_len = context_lens[seq_idx]; + const int *seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const int last_block_token_num = + context_len - (block_num - 1) * BLOCK_SIZE; + float *__restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_context_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + } + + // Compute softmax + if (alibi_slopes) { + reduceSoftmaxAlibi(thread_block_logits, context_len, + block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, + context_len); + } else { + reduceSoftmax(thread_block_logits, context_len, + block_num * BLOCK_SIZE); + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + std::free(logits); + } +}; + +#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + +template +void 
paged_attention_v1_impl_launcher( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + context_lens, max_context_len, alibi_slopes); + +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, + int num_kv_heads, float scale, + torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype) { + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); +} + +// Paged attention v2 +namespace { +template +struct paged_attention_v2_impl { + static void call( + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float + *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ 
alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads, const int max_num_partitions) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + static_assert(BLOCK_SIZE == 16); + static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); + static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); + +#pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int partition_idx = 0; partition_idx < max_num_partitions; + ++partition_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int start_token_idx = partition_idx * PARTITION_SIZE; + + if (start_token_idx >= context_len) + continue; + + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + const bool no_reduce = (partition_num == 1); + const int context_token_num = + (std::min(context_len, start_token_idx + PARTITION_SIZE) - + start_token_idx); + const int block_num = + (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int last_block_token_num = + context_token_num - (block_num - 1) * BLOCK_SIZE; + const int *seq_block_table = block_tables + + max_num_blocks_per_seq * seq_idx + + start_token_idx / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( + q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, + block_idx == block_num - 1 ? 
last_block_token_num : BLOCK_SIZE); + } + + std::pair max_and_sum; + if (alibi_slopes) { + max_and_sum = reduceSoftmaxAlibi( + logits, context_token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, context_len); + } else { + max_and_sum = reduceSoftmax(logits, context_token_num, + block_num * BLOCK_SIZE); + } + + auto &&[max_logit, exp_sum] = max_and_sum; + + scalar_t *__restrict__ output_buffer = nullptr; + if (!no_reduce) { + auto idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + max_logits[idx] = max_logit; + exp_sums[idx] = exp_sum; + output_buffer = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + output_buffer = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + output_buffer + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + reduceValueBlock( + prob_vec_ptr, v_block_cache_ptr, accums); + + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; + const scalar_t *__restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + vec_op::unroll_loop( + [&](int head_elem_idx) { + if (head_elem_idx % 2 == 0) { + vec_op::prefetch(next_v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + } + }); + } + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32(value, out_ptr + head_elem_idx); + }); + } + } + } + } + + // Rescale partition softmax and store the factors to exp_sums +#pragma omp parallel for collapse(2) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + reducePartitonSoftmax( + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions, + partition_num); + } + } + + // Reduce values + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); + constexpr int head_elem_num_per_group = + 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE + // didn't align with 64 bytes + static_assert(HEAD_SIZE % head_elem_num_per_group == 0); + constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; + const float *__restrict__ rescale_factors = exp_sums; +#pragma omp 
parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { + const int context_len = context_lens[seq_idx]; + const int partition_num = + (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + if (partition_num == 1) + continue; + + const float *__restrict__ seq_head_rescale_factors = + rescale_factors + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + const scalar_t *__restrict__ seq_head_tmp_out = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + group_idx * head_elem_num_per_group; + scalar_t *__restrict__ seq_head_output = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + group_idx * head_elem_num_per_group; + + vec_op::FP32Vec16 acc; + for (int i = 0; i < partition_num; ++i) { + vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); + v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); + vec_op::FP32Vec16 fp32_value(value); + acc = acc + fp32_value * rescale_factor; + } + v_load_vec_type cast_acc(acc); + cast_acc.save(seq_head_output); + } + } + } + } +}; + +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ + max_num_partitions); + +template +void paged_attention_v2_impl_launcher( + torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, + torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + int max_num_partitions = exp_sums.size(-1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, block_size, \ + max_context_len, alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype) { + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); +} diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp new file mode 100644 index 0000000000000..94f5affc39f02 --- /dev/null +++ b/csrc/cpu/cache.cpp @@ -0,0 +1,139 @@ +#include +#include + +#include "cpu_types.hpp" + +namespace { +template +void copy_blocks_cpu_impl( + std::vector &key_caches, + std::vector &value_caches, + const std::vector> mapping_pairs, + const int element_num_per_block, const int layer_num) { + const size_t pair_num = mapping_pairs.size(); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; +#pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { + int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t target_offset = + element_num_per_block * mapping_pairs[pair].second; + scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t *source_ptr = key_cache_ptr + source_offset; + scalar_t *target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + + scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + } + } +} + 
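+// Scatters each token's key/value heads into the paged caches:
+//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
+//   value_cache: [num_blocks, num_heads, head_size, block_size]
+// slot_mapping[token] selects block_index = slot / block_size and
+// block_offset = slot % block_size; negative slots (padding) are skipped.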
+template +void reshape_and_cache_cpu_impl( + const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, + scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx >= 0) { + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; + const scalar_t *src_key_head_ptr = key + src_key_head_idx; + const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + scalar_t *target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + scalar_t *target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + + for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { + const int64_t target_offset = + src_key_idx * block_size + block_offset * x; + for (int i = 0; i < x; ++i) { + target_key_head_ptr[target_offset + i] = + src_key_head_ptr[src_key_idx + i]; + } + } + + for (int src_value_idx = 0; src_value_idx < head_size; + ++src_value_idx) { + const int64_t target_offset = + src_value_idx * block_size + block_offset; + target_value_head_ptr[target_offset] = + src_value_head_ptr[src_value_idx]; + } + } + } + } +} +}; // namespace + +void copy_blocks(std::vector &key_caches, + std::vector &value_caches, + const std::map> &block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + + std::vector> mapping_pairs; + mapping_pairs.reserve(block_mapping.size()); + for (const auto &pair : block_mapping) { + for (const auto &dst : pair.second) { + mapping_pairs.emplace_back(pair.first, dst); + } + } + + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); +} + +void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping, + const std::string &kv_cache_dtype) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, + value_stride, num_heads, head_size, block_size, x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); +} + +void swap_blocks(torch::Tensor &src, 
torch::Tensor &dst, + const std::map &block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") +} diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp new file mode 100644 index 0000000000000..c1d3ec058b991 --- /dev/null +++ b/csrc/cpu/cpu_types.hpp @@ -0,0 +1,352 @@ + +#ifndef CPU_TYPES_HPP +#define CPU_TYPES_HPP + +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float 
values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const 
FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#endif + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp new file mode 100644 index 0000000000000..467f0dc84982c --- /dev/null +++ b/csrc/cpu/layernorm.cpp @@ -0,0 +1,117 @@ +#include "cpu_types.hpp" + +namespace { +template +void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + vec_op::FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 
0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_w(w); + + vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } +} + +template +void fused_add_rms_norm_impl(scalar_t *__restrict__ input, + scalar_t *__restrict__ residual, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto residual_p = residual + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t res(residual_p + j); + vec_op::FP32Vec8 fp32_x(x); + vec_op::FP32Vec8 fp32_res(res); + + fp32_x = fp32_x + fp32_res; + variance = variance + fp32_x * fp32_x; + scalar_vec_t out(fp32_x); + out.save(residual_p + j); + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t w(weight + j); + scalar_vec_t res(residual_p + j); + + vec_op::FP32Vec8 fp32_w(w); + vec_op::FP32Vec8 fp32_res(res); + + vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(input_p + j); + } + } +} +} // namespace + +void rms_norm(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) + }); +} + +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "fused_add_rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl) + fused_add_rms_norm_impl( + input.data_ptr(), residual.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl) + }); +} diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp new file mode 100644 index 0000000000000..e9b3992204bb2 --- /dev/null +++ b/csrc/cpu/pos_encoding.cpp @@ -0,0 +1,199 @@ + +#include "cpu_types.hpp" + +namespace { +template +void rotary_embedding_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int 
VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + constexpr int ELEM_SIZE = sizeof(scalar_t); + + const int embed_dim = rot_dim / 2; + TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(query + out_x); + const scalar_vec_t q_y(query + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(query + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(query + out_y); + } + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t k_x(key + out_x); + const scalar_vec_t k_y(key + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_k_x(k_x); + vec_op::FP32Vec8 fp32_k_y(k_y); + + auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; + scalar_vec_t(out1).save(key + out_x); + auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; + scalar_vec_t(out2).save(key + out_y); + } + } + } +} + +template +void rotary_embedding_gptj_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + const int embed_dim = rot_dim / 2; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; + scalar_t *head_query = token_head + query; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + 
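+        // GPT-J style: rotate the interleaved pair (2*j, 2*j+1) by the cos/sin at offset j.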
+ const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_query[x_index]; + const float y = head_query[y_index]; + + head_query[x_index] = x * cos - y * sin; + head_query[y_index] = y * cos + x * sin; + } + } + } + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_kv_heads; ++i) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t *cos_cache_ptr = cache_ptr; + const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + scalar_t *head_key = key + token_head; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; + const int y_index = 2 * rot_offset + 1; + + const float cos = cos_cache_ptr[rot_offset]; + const float sin = sin_cache_ptr[rot_offset]; + + const float x = head_key[x_index]; + const float y = head_key[y_index]; + + head_key[x_index] = x * cos - y * sin; + head_key[y_index] = y * cos + x * sin; + } + } + } +} +}; // namespace + +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding_impl", [&] { + CPU_KERNEL_GUARD_IN(rotary_embedding_impl) + if (is_neox) { + rotary_embedding_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } else { + rotary_embedding_gptj_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + } + + CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) + }); +} diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp new file mode 100644 index 0000000000000..bba044087f37c --- /dev/null +++ b/csrc/cpu/pybind.cpp @@ -0,0 +1,73 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // vLLM custom ops + pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); + + // Attention ops + ops.def( + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + ops.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); + + // Activation ops + ops.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); + ops.def( + "gelu_and_mul", + &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def( + "gelu_tanh_and_mul", + &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def( + "gelu_new", + &gelu_new, + "GELU implementation used in GPT-2."); + ops.def( + "gelu_fast", + &gelu_fast, + "Approximate GELU implementation."); + + // Layernorm + ops.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + ops.def( + 
"fused_add_rms_norm", + &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); + + // Rotary embedding + ops.def( + "rotary_embedding", + &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + + // Cache ops + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + cache_ops.def( + "swap_blocks", + &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def( + "copy_blocks", + ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def( + "reshape_and_cache", + &reshape_and_cache, + "Reshape the key and value tensors and cache them"); +} diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 96749b9327d7a..0e76763a87b7c 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -7,4 +7,6 @@ sphinx-argparse # packages to install to build the documentation pydantic -f https://download.pytorch.org/whl/cpu -torch \ No newline at end of file +torch +py-cpuinfo +transformers diff --git a/docs/source/conf.py b/docs/source/conf.py index 61d8e55d2cc6c..5619ea2191934 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -75,6 +75,7 @@ # Mock out external dependencies here. autodoc_mock_imports = [ + "cpuinfo", "torch", "transformers", "psutil", diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst new file mode 100644 index 0000000000000..ba8b0645adcdf --- /dev/null +++ b/docs/source/getting_started/cpu-installation.rst @@ -0,0 +1,87 @@ +.. _installation_cpu: + +Installation with CPU +======================== + +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. + +Table of contents: + +#. :ref:`Requirements ` +#. :ref:`Quick start using Dockerfile ` +#. :ref:`Build from source ` +#. :ref:`Performance tips ` + +.. _cpu_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Compiler: gcc/g++>=12.3.0 (recommended) +* Instruction set architecture (ISA) requirement: AVX512 is required. + +.. _cpu_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g . + $ docker run -it \ + --rm \ + --network=host \ + --cpuset-cpus= \ + --cpuset-mems= \ + vllm-cpu-env + +.. _build_cpu_backend_from_source: + +Build from source +----------------- + +- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: + +.. code-block:: console + + $ sudo apt-get update -y + $ sudo apt-get install -y gcc-12 g++-12 + $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +- Second, install Python packages for vLLM CPU backend building: + +.. code-block:: console + + $ pip install --upgrade pip + $ pip install wheel packaging ninja setuptools>=49.4.0 numpy + $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + +- Finally, build and install vLLM CPU backend: + +.. code-block:: console + + $ VLLM_TARGET_DEVICE=cpu python setup.py install + +.. note:: + - BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support. 
+ + - AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. + + - If you want to force enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building. + +.. _cpu_backend_performance_tips: + +Performance tips +----------------- + +- The vLLM CPU backend uses the environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. + +- The vLLM CPU backend uses OpenMP for thread-parallel computation. For the best performance on CPU, it is critical to isolate the CPU cores used by OpenMP threads from other thread pools (such as the web-service event loop) to avoid CPU oversubscription. + +- If using the vLLM CPU backend on a bare-metal machine, it is recommended to disable hyper-threading. + +- If using the vLLM CPU backend on a multi-socket machine with NUMA, take care to bind CPU cores and memory nodes in order to avoid remote memory node access. ``numactl`` is a useful tool for CPU core and memory binding on NUMA platforms. The ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful. + + + diff --git a/docs/source/index.rst b/docs/source/index.rst index 5196ef062dc19..390409204cbc3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ Documentation getting_started/installation getting_started/amd-installation getting_started/neuron-installation + getting_started/cpu-installation getting_started/quickstart .. toctree:: diff --git a/requirements-12.1.0.txt b/requirements-12.1.0.txt index 7e84034dde439..df0f6dd1ee3ca 100644 --- a/requirements-12.1.0.txt +++ b/requirements-12.1.0.txt @@ -6,7 +6,6 @@ sentencepiece # Required for LLaMA tokenizer. numpy torch == 2.1.2 requests -psutil py-cpuinfo transformers >= 4.39.1 # Required for StarCoder2 & Llava. xformers == 0.0.23.post1 # Required for CUDA 12.1. diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000..580bffea5a018 --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,15 @@ +cmake>=3.21 +ninja # For faster builds. +psutil +ray >= 2.9 +sentencepiece # Required for LLaMA tokenizer. +numpy +transformers >= 4.38.0 # Required for Gemma. +fastapi +uvicorn[standard] +pydantic >= 2.0 # Required for OpenAI server. +prometheus_client >= 0.18.0 +torch == 2.1.2+cpu +triton >= 2.1.0 +filelock == 3.13.3 +py-cpuinfo \ No newline at end of file diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 0dc2f0e664114..4e9f598551fee 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -15,3 +15,4 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server.
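As a usage companion to the CPU installation and performance notes added above, here is a minimal offline-inference sketch for a CPU build. It assumes vLLM was installed with `VLLM_TARGET_DEVICE=cpu` as documented; the script name, model, prompt, and the 4 GB KV-cache budget are illustrative placeholders, not part of this patch.

```python
# cpu_smoke_test.py -- run with the KV-cache budget exported first, e.g.:
#   VLLM_CPU_KVCACHE_SPACE=4 python cpu_smoke_test.py
from vllm import LLM, SamplingParams

# Any model supported by vLLM works here; opt-125m keeps the example small.
# BF16 is the CPU backend's native data type on AVX512 hardware.
llm = LLM(model="facebook/opt-125m", dtype="bfloat16")

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling_params)

for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```

On a multi-socket host the same script would typically be pinned with something like `numactl --cpunodebind=0 --membind=0`, in line with the NUMA guidance above.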
prometheus_client >= 0.18.0 outlines == 0.0.34 +tiktoken == 0.6.0 # Required for DBRX tokenizer diff --git a/setup.py b/setup.py index 0406c814c0de2..f4cb4e975f369 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) +# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] +VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") # vLLM only supports Linux platform assert sys.platform.startswith( @@ -115,6 +117,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DCMAKE_BUILD_TYPE={}'.format(cfg), '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir), '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp), + '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] verbose = bool(int(os.getenv('VERBOSE', '0'))) @@ -188,11 +191,14 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return torch.version.cuda is not None and not _is_neuron() + return VLLM_TARGET_DEVICE == "cuda" \ + and torch.version.cuda is not None \ + and not _is_neuron() def _is_hip() -> bool: - return torch.version.hip is not None + return (VLLM_TARGET_DEVICE == "cuda" + or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None def _is_neuron() -> bool: @@ -204,6 +210,10 @@ def _is_neuron() -> bool: return torch_neuronx_installed +def _is_cpu() -> bool: + return VLLM_TARGET_DEVICE == "cpu" + + def _install_punica() -> bool: return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) @@ -299,6 +309,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_cpu(): + version += "+cpu" else: raise RuntimeError("Unknown runtime environment") @@ -325,6 +337,9 @@ def get_requirements() -> List[str]: elif _is_neuron(): with open(get_path("requirements-neuron.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_cpu(): + with open(get_path("requirements-cpu.txt")) as f: + requirements = f.read().strip().split("\n") else: raise ValueError( "Unsupported platform, please use CUDA, ROCM or Neuron.") diff --git a/tests/conftest.py b/tests/conftest.py index 770da1e6f14b8..5c409c8cd5ee5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,10 +55,20 @@ def cleanup(): torch.cuda.empty_cache() +@pytest.fixture() +def should_do_global_cleanup_after_test() -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. + """ + return True + + @pytest.fixture(autouse=True) -def cleanup_fixture(): +def cleanup_fixture(should_do_global_cleanup_after_test: bool): yield - cleanup() + if should_do_global_cleanup_after_test: + cleanup() @pytest.fixture(scope="session") diff --git a/tests/core/block/conftest.py b/tests/core/block/conftest.py new file mode 100644 index 0000000000000..0464d6a74da61 --- /dev/null +++ b/tests/core/block/conftest.py @@ -0,0 +1,12 @@ +import pytest + + +@pytest.fixture() +def should_do_global_cleanup_after_test() -> bool: + """Disable the global cleanup fixture for tests in this directory. This + provides a ~10x speedup for unit tests that don't load a model to GPU. + + This requires that tests in this directory clean up after themselves if they + use the GPU. 
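Because the conftest override above disables the global autouse cleanup for `tests/core/block/`, a test in that directory that does touch the GPU has to clean up on its own. A minimal sketch of one way to do that, reusing the shared `cleanup()` helper from `tests/conftest.py`; the fixture, test body, and model name are hypothetical, not part of this patch.

```python
import pytest

from tests.conftest import cleanup
from vllm import LLM


@pytest.fixture()
def gpu_llm():
    # Hypothetical GPU-backed fixture: clean up explicitly, since the
    # directory-level conftest turns off the global autouse cleanup.
    llm = LLM(model="facebook/opt-125m")
    yield llm
    del llm
    cleanup()


def test_generate_smoke(gpu_llm):
    outputs = gpu_llm.generate(["Hello, my name is"], use_tqdm=False)
    assert outputs
```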
+ """ + return False diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index e1a9dd28e5737..1d99cb5d32219 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,25 +1,10 @@ -import contextlib -import gc - import pytest -import ray -import torch +from tests.conftest import cleanup from vllm import LLM -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) from vllm.model_executor.utils import set_random_seed -def cleanup(): - destroy_model_parallel() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - gc.collect() - torch.cuda.empty_cache() - ray.shutdown() - - @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, seed): diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 283d99fe0b193..5a7f828456e2d 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Use a large block size to trigger more copy-on-writes. + "block_size": 32, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify beam search equality with block manager v1 and v2. + + This requires copy-on-writes; if the v1 and v2 output is the same, then + we have some confidence cow is working. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + use_beam_search=True, + best_of=2, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # Our prompts will generate 128 tokens; since the prompts themselves are + # small, we don't need much KV space beyond 128. + "max_model_len": 160, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Lookahead scheduling only supported in v2 block manager. 
+ "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "block_size": 16, + + # Allow only 2 sequences of ~128 tokens in worst case. + # Note 8 = 128/block_size + "forced_num_gpu_blocks": 2 * (8 + 1), + }, + { + "block_size": 8, + + # Allow only 2 sequences of ~128 tokens in worst case. + # Note 16 = 128/block_size + "forced_num_gpu_blocks": 2 * (16 + 1), + } + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "num_lookahead_slots": 0, +}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + # We run one test with block_size < lookahead_slots, one test with + # block_size > lookahead_slots + "num_lookahead_slots": 10, + }]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size): + """Verify vLLM produces the same output with greedy sampling, when lookahead + scheduling is used vs. not. + + Lookahead scheduling is not expected to modify the output, as it simply + allocates empty slots ahead of the known token ids in a sliding fashion. + + This test constrains the total number of blocks to force preemption. It also + varies the block size so that the lookahead size is less than and greater + than the block size. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids without lookahead scheduling') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with lookahead scheduling') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py new file mode 100644 index 0000000000000..1e8e4ccdfb151 --- /dev/null +++ b/tests/core/block/test_block_manager_v2.py @@ -0,0 +1,103 @@ +import pytest + +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.interfaces import AllocStatus +from vllm.sequence import Logprob, SequenceStatus +from vllm.utils import chunk_list + +from ..utils import create_seq_group + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, + num_gpu_blocks: int, watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current 
implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): + seq_group = create_seq_group( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + ) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + num_output_blocks + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("prompt_len", [1, 7, 8]) +@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) +@pytest.mark.parametrize("num_lookahead_slots", [0, 10]) +def test_append_slots(block_size, prompt_len, num_slots_to_append, + num_lookahead_slots): + """Verify append_slots consumes the correct number of blocks from the block + table. + """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + ) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Append slots for new tokens and lookahead slots. + free_blocks_before_append = block_manager.get_num_free_gpu_blocks() + block_manager.append_slots(seq, num_lookahead_slots) + num_consumed_blocks = (free_blocks_before_append - + block_manager.get_num_free_gpu_blocks()) + + # Expect consumed blocks to be new blocks required to support the new slots. 
+ expected_consumed_blocks = len( + chunk_list( + list( + range(prompt_len + num_slots_to_append + num_lookahead_slots)), + block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) + assert num_consumed_blocks == expected_consumed_blocks diff --git a/tests/core/block/test_block_space_manager.py b/tests/core/block/test_block_space_manager.py deleted file mode 100644 index eec8cbcb38803..0000000000000 --- a/tests/core/block/test_block_space_manager.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 -from vllm.core.interfaces import AllocStatus - -from ..utils import create_seq_group - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. - num_output_blocks = num_output_blocks_per_seq - - for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): - seq_group = create_seq_group( - seq_prompt_lens=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - ) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + num_output_blocks - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index a7c5aa2b1df59..3481d6b4312c1 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, # After free, expect all blocks to be freed. assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("num_new_tokens", [1, 16, 129]) +@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, + num_new_tokens: int, + num_lookahead_slots: int, + allocator_type: str): + """Verify correct calculation of get_num_blocks_touched_by_append_slots. + + This is done by using copy-on-write, which requires any modified block to + be copied before write if the refcount > 1. We set the refcount>1 by forking + a sequence, then measure the free blocks before and after an append. If the + number of consumed blocks equals what `get_num_blocks_touched_by_append_ + slots` returns, then the calculation is correct. 
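The `expected_consumed_blocks` arithmetic in `test_append_slots` above is a difference of ceiling divisions: the blocks needed to hold prompt + appended + lookahead slots, minus the blocks the prompt already occupies. A small sketch of that identity, assuming `chunk_list` splits a list into fixed-size chunks as in `vllm.utils`; the helper names here are illustrative.

```python
import math


def blocks_needed(num_slots: int, block_size: int) -> int:
    # ceil(num_slots / block_size): blocks required to hold num_slots slots.
    return math.ceil(num_slots / block_size)


def expected_new_blocks(prompt_len: int, num_appended: int,
                        num_lookahead: int, block_size: int) -> int:
    total_slots = prompt_len + num_appended + num_lookahead
    return (blocks_needed(total_slots, block_size) -
            blocks_needed(prompt_len, block_size))


# Example from the parametrization above: block_size=8, prompt_len=7,
# 129 appended tokens, 10 lookahead slots -> 146 slots need 19 blocks,
# the prompt already covers 1 block, so 18 new blocks are consumed.
assert expected_new_blocks(7, 129, 10, block_size=8) == 18
```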
+ """ + + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(num_new_tokens)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + block_table.allocate(token_ids=token_ids, device=Device.GPU) + + # Add lookahead before fork so both sequences have the same lookahead + # blocks. + block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots) + + # Fork sequence so that every block has refcount > 1. + _ = block_table.fork() + + # Determine how many blocks should be touched. + expected_num_touched_blocks = ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=token_ids_to_append, + num_lookahead_slots=num_lookahead_slots)) + + # Measure how many blocks are touched by measuring num_free_blocks before + # and after the append. + # + # We expect append_token_ids to CoW all mutated blocks that have refcount>1. + num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) + block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) + num_consumed_blocks = (num_free_blocks_before_append - + allocator.get_num_free_blocks(Device.GPU)) + + # TODO(cade) ensure equality when num_lookahead_slots > 0. + # The reason we have < is because lookahead blocks are not copied eagerly; + # they are copied on first write. This will cause issues for beam search + + # speculative decoding. This is acceptable for now as it is a large effort + # to combine the two. To fix this, we can ensure single sequence ownership + # of lookahead blocks by appending empty slots to each block, which will + # trigger the CoW. + # + # Until then, we can accept that the consumed tokens are <= the expected + # tokens when appending with lookahead. + if num_lookahead_slots > 0: + assert num_consumed_blocks <= expected_num_touched_blocks + else: + assert num_consumed_blocks == expected_num_touched_blocks diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 93226cba1909c..62984ef4caabb 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,9 +103,9 @@ def test_append_slot_single_seq(): block_manager.allocate(seq_group) # Nothing to append. Sequence has no new logical blocks. 
- assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slot(prompt) + assert not block_manager.append_slots(prompt) after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks == after_blocks @@ -114,9 +114,9 @@ def test_append_slot_single_seq(): token_id = i + 5 prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slot(prompt) + assert not block_manager.append_slots(prompt) after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks - after_blocks == 1 @@ -150,13 +150,13 @@ def test_append_slot_cow(): child.append_token_id(token_id, {token_id: Logprob(0.0)}) block_manager.fork(prompt, child) - assert block_manager.can_append_slot(seq_group) + assert block_manager.can_append_slots(seq_group) before_blocks = block_manager.get_num_free_gpu_blocks() - maybe_src_dst_block = block_manager.append_slot(child) - assert maybe_src_dst_block is not None - src_block, dst_block = maybe_src_dst_block - assert src_block != dst_block + cows = block_manager.append_slots(child) + assert cows + for src_block, dst_blocks in cows.items(): + assert src_block not in dst_blocks after_blocks = block_manager.get_num_free_gpu_blocks() assert before_blocks - after_blocks == 1 @@ -184,7 +184,7 @@ def test_fork(): token_id = 4 # Append token to child. Block is shared so copy on write occurs. child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(child) + block_manager.append_slots(child) assert block_manager.get_block_table( prompt) != block_manager.get_block_table(child) @@ -325,7 +325,7 @@ def test_sliding_window_multi_seq(): token_id = 4 # Append token to child. Block is shared so copy on write occurs. child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(child) + block_manager.append_slots(child) # assert the number of blocks allocated is correct # we will use now one block more. Each seq will use 2 blocks, @@ -335,7 +335,7 @@ def test_sliding_window_multi_seq(): token_id = 5 parent.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slot(parent) + block_manager.append_slots(parent) # assert the number of blocks allocated is correct # no change, because both sequences are still just sharing one block diff --git a/tests/core/utils.py b/tests/core/utils.py index 2e462f2aec4d4..9482c7761c286 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -24,7 +24,7 @@ def create_dummy_prompt( def create_seq_group( - seq_prompt_lens=1024, + seq_prompt_len=1024, seq_output_lens=(128, ), request_id='0', seq_id_start=0, @@ -32,7 +32,7 @@ def create_seq_group( assert len(seq_output_lens) > 0 - prompt_token_ids = [0] * seq_prompt_lens + prompt_token_ids = [0] * seq_prompt_len seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py new file mode 100644 index 0000000000000..cd64622e2226f --- /dev/null +++ b/tests/quantization/test_autogptq_marlin_configs.py @@ -0,0 +1,68 @@ +"""Tests whether Marlin models can be loaded from the autogptq config. + +Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`. 
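The new quantization test above exercises the config detection reworked in `vllm/config.py` later in this diff: a GPTQ checkpoint is treated as Marlin-serialized either via the newer `checkpoint_format` key or the legacy `is_marlin_format` flag in the HF `quantization_config`. A standalone sketch of that detection logic; the helper name and the example dicts are illustrative.

```python
from typing import Any, Dict


def resolve_quant_method(quant_cfg: Dict[str, Any]) -> str:
    """Sketch of the Marlin-format detection added to ModelConfig."""
    quant_method = str(quant_cfg.get("quant_method", "")).lower()
    # compat: autogptq >= 0.8.0 uses checkpoint_format: str,
    #         autogptq <= 0.7.1 uses is_marlin_format: bool.
    is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                        or quant_cfg.get("is_marlin_format", False))
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method


assert resolve_quant_method({"quant_method": "gptq",
                             "checkpoint_format": "marlin"}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq",
                             "is_marlin_format": True}) == "marlin"
assert resolve_quant_method({"quant_method": "gptq"}) == "gptq"
```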
+""" + +from dataclasses import dataclass + +import pytest + +from vllm.config import ModelConfig + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +# Model Id // Expected Kernel +MODELS_QUANT_TYPE = [ + # compat: autogptq <=0.7.1 is_marlin_format: bool + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"), + # compat: autogptq >=0.8.0 use checkpoint_format: str + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq") +] + + +@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE) +def test_auto_gptq(model_quant_type: str, ) -> None: + model_path, quant_type = model_quant_type + + model_config_no_quant_arg = ModelConfig( + model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + quantization=None # case 1 + ) + + model_config_quant_arg = ModelConfig( + model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + quantization="gptq" # case 2 + ) + + assert model_config_no_quant_arg.quantization == quant_type, ( + f"Expected quant_type == {quant_type} for {model_path}, " + f"but found {model_config_no_quant_arg.quantization} " + "for no --quantization None case") + + assert model_config_quant_arg.quantization == quant_type, ( + f"Expected quant_type == {quant_type} for {model_path}, " + f"but found {model_config_quant_arg.quantization} " + "for --quantization gptq case") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 92587b40dd45a..9bc9becb2a6f1 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -4,8 +4,8 @@ from transformers import AutoTokenizer from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer import (Detokenizer, + detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group TRUTH = [ diff --git a/vllm/__init__.py b/vllm/__init__.py index d53e591bcb062..52c36f55e9ebe 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -8,7 +8,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.4.0" +__version__ = "0.4.0.post1" __all__ = [ "LLM", diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000..4f69ebef662cb --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,253 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": + return 
TorchSDPAMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + prompt_lens: Optional[List[int]] + prompt_lens_tensor: Optional[torch.Tensor] + num_prompt_tokens: int + num_generation_tokens: int + + max_subquery_len: Optional[int] = None + max_prompt_len: Optional[int] = None + subquery_start_loc: Optional[torch.Tensor] = None + seq_start_loc: Optional[torch.Tensor] = None + use_cuda_graph: bool = False + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + +class TorchSDPABackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + assert len(alibi_slopes) == num_heads + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype) + + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.prompt_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.prompt_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): + end = start + prompt_len + sub_out = scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=not self.need_mask, + scale=self.scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + else: + # prefix-enabled attention + raise RuntimeError( + "Torch SDPA backend doesn't support prefix decoding.") + + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + prompt_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, prompt_len, prompt_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + prompt_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for prompt_len in prompt_lens: + tensor = torch.full( + (1, prompt_len, prompt_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c2ec4376c9f3c..b5cd39bbe6252 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -5,7 +5,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) @@ -17,6 +17,10 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend + elif is_cpu(): + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend else: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -29,6 +33,8 @@ def _can_use_flash_attn(dtype: torch.dtype) -> bool: # AMD GPUs. logger.info("Cannot use FlashAttention backend for AMD GPUs.") return False + if is_cpu(): + return False if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " diff --git a/vllm/config.py b/vllm/config.py index 62f1d70079648..eef3fc53c3a65 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, + is_neuron) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -171,26 +172,28 @@ def _verify_quantization(self) -> None: self.quantization = self.quantization.lower() # Parse quantization method from the HF model config, if available. - hf_quant_config = getattr(self.hf_config, "quantization_config", None) - if hf_quant_config is not None: - hf_quant_method = str(hf_quant_config["quant_method"]).lower() - - # If the GPTQ model is serialized in marlin format, use marlin. 
- if (hf_quant_method == "gptq" - and "is_marlin_format" in hf_quant_config - and hf_quant_config["is_marlin_format"]): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" + or quant_cfg.get("is_marlin_format", False)) + + # Use marlin if the GPTQ model is serialized in marlin format. + if quant_method == "gptq" and is_format_marlin: logger.info("The model is serialized in Marlin format. " "Using Marlin kernel.") - hf_quant_method = "marlin" + quant_method = "marlin" if self.quantization == "gptq": - self.quantization = hf_quant_method + self.quantization = quant_method if self.quantization is None: - self.quantization = hf_quant_method - elif self.quantization != hf_quant_method: + self.quantization = quant_method + elif self.quantization != quant_method: raise ValueError( "Quantization method specified in the model config " - f"({hf_quant_method}) does not match the quantization " + f"({quant_method}) does not match the quantization " f"method specified in the `quantization` argument " f"({self.quantization}).") @@ -530,9 +533,13 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. + num_lookahead_slots: The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. - use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. """ @@ -543,6 +550,7 @@ def __init__( max_num_seqs: int, max_model_len: int, use_v2_block_manager: bool = False, + num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, ) -> None: @@ -554,9 +562,11 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.delay_factor = delay_factor self.use_v2_block_manager = use_v2_block_manager + self.num_lookahead_slots = num_lookahead_slots + self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self._verify_args() def _verify_args(self) -> None: @@ -568,12 +578,19 @@ def _verify_args(self) -> None: "max_num_batched_tokens and makes vLLM reject longer " "sequences. 
Please increase max_num_batched_tokens or " "decrease max_model_len.") + if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + class DeviceConfig: @@ -582,6 +599,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_cpu(): + self.device_type = "cpu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 793c6698633af..ba061bbc4fbcb 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -85,7 +85,9 @@ def allocate(self, device=device) self._num_full_slots = len(token_ids) - def append_token_ids(self, token_ids: List[int]) -> None: + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -102,14 +104,13 @@ def append_token_ids(self, token_ids: List[int]) -> None: token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated + assert token_ids, "can't append empty token ids" - self.ensure_num_empty_slots(num_empty_slots=len(token_ids)) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots) blocks = self._blocks[self._num_full_slots // self._block_size:] - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] + chunk_list( - token_ids[first_chunk_size:], self._block_size) + token_blocks = self._chunk_token_blocks_for_append(token_ids) for block, token_block in zip(blocks, token_blocks): block.append_token_ids(token_block) @@ -195,6 +196,25 @@ def physical_block_ids(self) -> List[int]: assert self._is_allocated return [block.block_id for block in self._blocks] + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return sequence_token_ids[self.num_full_slots:] + def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], token_ids: List[int], device: Device) -> List[Block]: @@ -243,3 +263,29 @@ def num_full_slots(self) -> int: int: The total number of tokens currently stored in the BlockTable. """ return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. 
+ """ + + all_token_ids = token_ids + [-1] * num_lookahead_slots + token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + return len(token_blocks) + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + """ + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + chunk_list( + token_ids[first_chunk_size:], self._block_size) + return token_blocks diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 160a86556f031..b2aaeb33c5299 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor @@ -292,7 +292,12 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: + def can_append_slots(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + # Simple heuristic: If there is at least one free block # for each sequence, we can append. num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -364,10 +369,11 @@ def _allocate_last_physical_block( assert new_block.ref_count == 1 return new_block - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int = 0, + ) -> Dict[int, List[int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] @@ -386,7 +392,7 @@ def append_slot( # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) - return None + return {} # We want to append the token to the last physical block. last_block = block_table[-1] @@ -399,7 +405,7 @@ def append_slot( maybe_new_block = self._maybe_promote_last_block( seq, last_block) block_table[-1] = maybe_new_block - return None + return {} else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. @@ -407,7 +413,7 @@ def append_slot( block_table[-1] = new_block self.gpu_allocator.free(last_block) - return last_block.block_number, new_block.block_number + return {last_block.block_number: [new_block.block_number]} def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. 
@@ -433,7 +439,11 @@ def _get_physical_blocks( blocks.update(self.block_tables[seq.seq_id]) return list(blocks) - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_free_blocks = self.gpu_allocator.get_num_free_blocks() @@ -443,7 +453,12 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: num_required_blocks = len(blocks) + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> Dict[int, int]: + assert (num_lookahead_slots == 0 + ), "BlockSpaceManagerV1 does not support lookahead allocation" + # CPU block -> GPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 37c70073b663b..813e71ad883b2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,5 +1,5 @@ """A block manager that manages token blocks.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator @@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager): sliding-window are not feature complete. This class implements the design described in https://github.com/vllm-project/vllm/pull/3492. + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + Args: block_size (int): The size of each memory block. num_gpu_blocks (int): The number of memory blocks allocated on GPU. @@ -116,35 +134,51 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. 
+ + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( Device.GPU) - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks + return num_touched_blocks <= num_free_gpu_blocks - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int, + ) -> Dict[int, List[int]]: block_table = self.block_tables[seq.seq_id] - # Get unseen token ids. - num_full_slots = block_table.num_full_slots - unseen_token_ids = seq.get_token_ids()[num_full_slots:] - assert unseen_token_ids - - block_table.append_token_ids(unseen_token_ids) - - # Return any copy-on-writes. - _ = self.block_allocator.clear_copy_on_writes() - - # TODO extend append_slot interface to append_slots - # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250 + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + ) - return None + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows def free(self, seq: Sequence) -> None: if seq.seq_id not in self.block_tables: @@ -191,10 +225,12 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: return False - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> Dict[int, int]: raise NotImplementedError def can_swap_out(self, seq_group: SequenceGroup) -> bool: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 48524de0df8ea..711536bcc97be 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,6 +1,6 @@ import enum from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List from vllm.sequence import Sequence, SequenceGroup @@ -44,14 +44,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: pass @abstractmethod - def can_append_slot(self, seq_group: SequenceGroup) -> bool: + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: pass @abstractmethod - def append_slot( + def append_slots( self, seq: Sequence, - ) -> Optional[Tuple[int, int]]: + num_lookahead_slots: int, + ) -> Dict[int, List[int]]: pass @abstractmethod @@ -59,11 +61,13 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: pass @abstractmethod - def can_swap_in(self, seq_group: SequenceGroup) -> bool: + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: pass @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_in(self, seq_group: SequenceGroup, + 
num_lookahead_slots: int) -> Dict[int, int]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 04e8056aab544..9d098801233e2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -52,6 +52,7 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + num_lookahead_slots: int, ) -> None: """A list of sequence groups to be scheduled as a single batch. @@ -86,6 +87,7 @@ def __init__( # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) + self.num_lookahead_slots = num_lookahead_slots self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: @@ -309,6 +311,8 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True), ) return scheduler_outputs @@ -323,7 +327,7 @@ def _schedule(self) -> SchedulerOutputs: preempted: List[SequenceGroup] = [] while self.running: seq_group = self.running.popleft() - while not self.block_manager.can_append_slot(seq_group): + while not self._can_append_slots(seq_group): if self.running: # Preempt the lowest-priority sequence groups. victim_seq_group = self.running.pop() @@ -337,7 +341,7 @@ def _schedule(self) -> SchedulerOutputs: break else: # Append new slots to the sequence group. - self._append_slot(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy) running.append(seq_group) self.running = running @@ -366,7 +370,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): + if not self._can_swap_in(seq_group): break # The total number of sequences in the RUNNING state should not @@ -380,7 +384,7 @@ def _schedule(self) -> SchedulerOutputs: curr_loras.add(lora_int_id) self.swapped.popleft() self._swap_in(seq_group, blocks_to_swap_in) - self._append_slot(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) @@ -405,9 +409,32 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False), ) return scheduler_outputs + def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # Appending slots only occurs in decoding. + is_prefill = False + + return self.block_manager.can_append_slots( + seq_group=seq_group, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + ) + + def _can_swap_in(self, seq_group: SequenceGroup) -> bool: + # Swapping in is considered decode. + is_prefill = False + + return self.block_manager.can_swap_in( + seq_group=seq_group, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + ) + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. 
# This function call changes the internal states of the scheduler @@ -482,19 +509,30 @@ def _allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slot( + def _append_slots( self, seq_group: SequenceGroup, blocks_to_copy: Dict[int, List[int]], ) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source + block indices to lists of destination block indices. This + dictionary is updated with the new source and destination block + indices for the appended slots. + """ + num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - ret = self.block_manager.append_slot(seq) - if ret is not None: - src_block, dst_block = ret - if src_block in blocks_to_copy: - blocks_to_copy[src_block].append(dst_block) - else: - blocks_to_copy[src_block] = [dst_block] + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + + for src, dests in cows.items(): + if src not in blocks_to_copy: + blocks_to_copy[src] = [] + blocks_to_copy[src].extend(dests) def _preempt( self, @@ -588,3 +626,16 @@ def _passed_delay(self, now: float) -> bool: else: passed_delay = True return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. + """ + if is_prefill: + return 0 + + return self.scheduler_config.num_lookahead_slots diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 83ef7ca182c3d..8d61f2f9ff193 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -53,8 +53,8 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False - forced_num_gpu_blocks: Optional[int] = None + num_lookahead_slots: int = 0 # Related to Vision-language models such as llava image_input_type: Optional[str] = None @@ -202,6 +202,14 @@ def add_cli_args( parser.add_argument('--use-v2-block-manager', action='store_true', help='Use BlockSpaceMangerV2') + parser.add_argument( + '--num-lookahead-slots', + type=int, + default=EngineArgs.num_lookahead_slots, + help='Experimental scheduling config necessary for ' + 'speculative decoding. 
This will be replaced by ' + 'speculative config in the future; it is present ' + 'to enable correctness tests until then.') parser.add_argument('--seed', type=int, @@ -324,7 +332,7 @@ def add_cli_args( parser.add_argument("--device", type=str, default=EngineArgs.device, - choices=["auto", "cuda", "neuron"], + choices=["auto", "cuda", "neuron", "cpu"], help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser.add_argument( @@ -406,6 +414,7 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.use_v2_block_manager, + num_lookahead_slots=self.num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dec42c633b10b..cd7fc5fdfcee1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -13,7 +13,6 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.model_loader import get_architecture_class_name from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, @@ -115,6 +114,8 @@ def __init__( # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) usage_message.report_usage( get_architecture_class_name(model_config), usage_context, @@ -178,6 +179,9 @@ def from_engine_args( if device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor + elif device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor elif parallel_config.worker_use_ray: initialize_ray_cluster(parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e550943c88725..32282bfd8d12b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -127,7 +127,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @app.middleware("http") async def authentication(request: Request, call_next): - if not request.url.path.startswith("/v1"): + root_path = "" if args.root_path is None else args.root_path + if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py new file mode 100644 index 0000000000000..7b3cc784c98e5 --- /dev/null +++ b/vllm/executor/cpu_executor.py @@ -0,0 +1,154 @@ +import os +from typing import Dict, List, Optional + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import get_distributed_init_method, get_ip, get_open_port + +logger = init_logger(__name__) + + +class CPUExecutor(ExecutorBase): + + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + 
scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: + assert device_config.device_type == "cpu" + assert lora_config is None, "cpu backend doesn't support LoRA" + model_config = _verify_and_get_model_config(model_config) + cache_config = _verify_and_get_cache_config(cache_config) + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + # Instantiate the worker and load the model to CPU. + self._init_worker() + self._init_cache() + + def _init_worker(self): + from vllm.worker.cpu_worker import CPUWorker + + assert self.parallel_config.world_size == 1, ( + "CPUExecutor only supports single CPU socket currently.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = CPUWorker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _init_cache(self) -> None: + num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num( + block_size=self.cache_config.block_size, + cache_space=self.cache_config.cpu_kvcache_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + ) + + logger.info(f"# CPU blocks: {num_cpu_blocks}") + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore + self.cache_config.num_cpu_blocks = 0 # type: ignore + + # Initialize the cache. + self.driver_worker.init_cache_engine(cache_config=self.cache_config) + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + output = self.driver_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def list_loras(self) -> List[int]: + raise NotImplementedError("LoRA is not implemented for cpu backend.") + + def check_health(self) -> None: + # CPUExecutor will always be healthy as long as + # it's running. 
+ return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.float16: + logger.warning("float16 is not supported on CPU, casting to bfloat16.") + config.dtype = torch.bfloat16 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + _GB = 1 << 30 + if config.enable_prefix_caching: + logger.warning("Prefix caching is not supported on CPU, disable it.") + config.enable_prefix_caching = False + + kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") + kv_cache_space = int(kv_cache_space_str) + + if kv_cache_space >= 0: + if kv_cache_space == 0: + config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore + logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/spec_decode/__init__.py b/vllm/spec_decode/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 419687e23b718..486c1938e1e10 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,10 +1,8 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens, - detokenize_incrementally) from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -148,10 +146,160 @@ def decode_sequence_inplace(self, seq: Sequence, ) sample_logprob.decoded_token = new_text - if seq.tokens is None: - seq.tokens = new_tokens - else: - seq.tokens.extend(new_tokens) + seq.tokens.extend(new_tokens) seq.prefix_offset = prefix_offset seq.read_offset = read_offset seq.output_text += new_decoded_token_text + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. 
+ sub_texts = [] + current_sub_text = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. + new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. 
+ """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + + # If the new token id is out of bounds, return an empty string. + if new_token_id >= len(tokenizer): + new_tokens = [""] + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 3bda3f419d8a2..e216a99af91f9 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -126,157 +126,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) - - -def _convert_tokens_to_string_with_added_encoders( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - output_tokens: List[str], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, -) -> str: - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 - # NOTE(woosuk): The following code is slow because it runs a for loop over - # the output_tokens. In Python, running a for loop over a list can be slow - # even when the loop body is very simple. 
- sub_texts = [] - current_sub_text = [] - all_special_tokens = set(tokenizer.all_special_tokens) - for token in output_tokens: - if skip_special_tokens and token in all_special_tokens: - continue - if token in tokenizer.get_added_vocab(): - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - current_sub_text = [] - sub_texts.append(token) - else: - current_sub_text.append(token) - if current_sub_text: - sub_text = tokenizer.convert_tokens_to_string(current_sub_text) - sub_texts.append(sub_text) - if spaces_between_special_tokens: - return " ".join(sub_texts) - else: - return "".join(sub_texts) - - -# 5 is an arbitrary value that should work for all -# tokenizers (bigger = more conservative). -INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 - - -def convert_prompt_ids_to_tokens( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - prompt_ids: List[int], - skip_special_tokens: bool = False, -) -> Tuple[List[str], int, int]: - """Converts the prompt ids to tokens and returns the tokens and offsets - for incremental detokenization. - - Note that not all tokens are converted to strings. Only the tokens that - are necessary for incremental detokenization are converted to strings. - """ - # Offset a little more in case we have special tokens. - prefix_offset = max( - len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0) - # We do not need to convert the whole prompt to tokens. - new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens) - prefix_offset = max( - len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) - read_offset = len(new_tokens) - return new_tokens, prefix_offset, read_offset - - -# Based on -# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 -# under Apache 2.0 license -def detokenize_incrementally( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - all_input_ids: List[int], - prev_tokens: Optional[List[str]], - prefix_offset: int, - read_offset: int, - skip_special_tokens: bool = False, - spaces_between_special_tokens: bool = True, -) -> Tuple[List[str], str, int, int]: - """Detokenizes the input ids incrementally and returns the new tokens - and the new text. - - If `prev_tokens` is None, this function will convert the input ids to - tokens and return the tokens and the new text. Otherwise, it will return the - new tokens and the new text. - - This function will also return the new prefix offset and the new read - offset to be used in the next iteration. - - The offsets are necessary to defeat cleanup algorithms in the decode which - decide to add a space or not depending on the surrounding ids. - - Args: - tokenizer: The tokenizer to use. - all_input_ids: The input ids. The last id is the new token id. - prev_tokens: The previous tokens. If None, this function will convert - the input ids to tokens and return the tokens and the new text. - prefix_offset: The prefix offset. - read_offset: The read offset. - skip_special_tokens: Whether to skip special tokens. - spaces_between_special_tokens: Whether to add spaces between special - tokens. 
- """ - new_token_id = all_input_ids[-1] - # This is the first iteration for this sequence - is_first_iter = prev_tokens is None - if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) - - # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: - # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) - output_tokens = prev_tokens + new_tokens - - # If this is the first iteration, return all tokens. - if is_first_iter: - new_tokens = output_tokens - - # The prefix text is necessary only to defeat cleanup algorithms in - # the decode which decide to add a space or not depending on the - # surrounding ids. - if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) - else: - prefix_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - new_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - if len(new_text) > len(prefix_text) and not new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. 
- # If it's in the middle, it's probably a real invalid id generated - # by the model - new_text = new_text[len(prefix_text):] - return new_tokens, new_text, read_offset, len(output_tokens) - else: - return new_tokens, "", prefix_offset, read_offset diff --git a/vllm/utils.py b/vllm/utils.py index 93fff4ffc9361..17b97f393ff21 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -117,6 +117,13 @@ def is_hip() -> bool: return torch.version.hip is not None +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import version + is_cpu_flag = "cpu" in version("vllm") + return is_cpu_flag + + @lru_cache(maxsize=None) def is_neuron() -> bool: try: @@ -362,6 +369,9 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif is_cpu(): + print_warning_once("Pin memory is not supported on CPU.") + return False return True diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 0000000000000..262ed9abd36b7 --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,280 @@ +"""A CPU worker class.""" +from typing import Dict, List, Optional + +import torch +import torch.distributed + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.worker.model_runner import ModelRunner + +logger = init_logger(__name__) + + +class CPUModelRunner(ModelRunner): + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend(model_config.dtype) + + # Initialize the cache. 
+ self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker: + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = CPUModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cpu_cache = None + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def get_cpu_cache_block_num( + self, + block_size: int, + cache_space: int, + cache_dtype: str, + ) -> int: + """ + Args: + block_size: The size of the cache block. + cache_space: The size of the CPU KV cache space in bytes. 
+ """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = CPUCacheEngine.get_cache_block_size( + block_size, cache_dtype, self.model_config, self.parallel_config) + num_cpu_blocks = int(cache_space // cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + return num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CPUCacheEngine(self.cache_config, + self.model_config, + self.parallel_config, + self.device_config) + self.cpu_cache = self.cache_engine.cpu_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.cpu_cache is not None + + # Populate the cache to warmup the memory + for layer_cache in self.cpu_cache: + layer_cache.fill_(0) + + def cache_copy( + self, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + assert len(blocks_to_swap_in) == 0 + assert len(blocks_to_swap_out) == 0 + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.cpu_cache) + return output + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch " + "world size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + backend = "gloo" + torch.distributed.init_process_group( + backend=backend, + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size)