Merge remote-tracking branch 'upstream/main'
alpayariyak committed Apr 2, 2024
2 parents d7d8ae8 + a3c226e commit 5547403
Showing 52 changed files with 3,608 additions and 333 deletions.
14 changes: 14 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -0,0 +1,14 @@
# This script builds the CPU docker image and runs the offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
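
For reference, examples/offline_inference.py exercises the basic generation path that this sanity check drives; the VLLM_CPU_KVCACHE_SPACE variable passed to docker run above sets the CPU KV-cache budget (in GiB). A minimal sketch of that flow — the prompt, sampling values, and model name below are illustrative, not taken from this commit:

from vllm import LLM, SamplingParams

# Illustrative prompt and sampling settings.
prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Build the engine and generate. On the CPU backend, the KV-cache size
# is governed by the VLLM_CPU_KVCACHE_SPACE environment variable.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)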
3 changes: 3 additions & 0 deletions .buildkite/test-template.j2
@@ -8,6 +8,9 @@ steps:
queue: amd
command: bash .buildkite/run-amd-test.sh

- label: "CPU Test"
command: bash .buildkite/run-cpu-test.sh

- label: ":docker: build image"
commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
3 changes: 2 additions & 1 deletion .github/workflows/scripts/build.sh
@@ -15,6 +15,7 @@ $python_executable -m pip install -r requirements.txt
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1

# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)

project(vllm_extensions LANGUAGES CXX)

option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")

message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else()
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
endif()
return()
endif()

#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.
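The VLLM_TARGET_DEVICE dispatch above is driven from the Python build: Dockerfile.cpu below runs setup.py with VLLM_TARGET_DEVICE=cpu in the environment, and the build hands that value to CMake. A rough sketch of that plumbing, with hypothetical names — this is not vLLM's actual setup.py code:

import os
import subprocess

# Hypothetical sketch: read the target device from the environment and
# forward it to CMake, where the VLLM_TARGET_DEVICE option selects the
# cpu_extension.cmake path added in this commit.
target_device = os.environ.get("VLLM_TARGET_DEVICE", "cuda")
subprocess.check_call(
    ["cmake", "-S", ".", "-B", "build", f"-DVLLM_TARGET_DEVICE={target_device}"]
)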
20 changes: 20 additions & 0 deletions Dockerfile.cpu
@@ -0,0 +1,20 @@
# This vLLM Dockerfile is used to construct an image that can build and run vLLM on the x86 CPU platform.

FROM ubuntu:22.04

RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu

RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

CMD ["/bin/bash"]
6 changes: 3 additions & 3 deletions benchmarks/backend_request_func.py
@@ -334,7 +334,8 @@ async def async_request_openai_chat_completions(
timestamp = time.perf_counter()
data = json.loads(chunk)

if "content" in data["choices"][0]["delta"]:
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
if ttft == 0:
ttft = time.perf_counter() - st
@@ -345,8 +346,7 @@
output.itl.append(timestamp -
most_recent_timestamp)

generated_text += data["choices"][0]["delta"][
"content"]
generated_text += delta["content"]

most_recent_timestamp = timestamp

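The rewrite above matters because streaming chat deltas can carry no text at all: the first delta typically holds only the role, and some chunks carry an empty content string. delta.get("content", None) is falsy in both cases, whereas the old "content" in delta check only caught a missing key. A self-contained sketch of the pattern — the chunk payloads are fabricated for illustration:

import json

chunks = [
    '{"choices": [{"delta": {"role": "assistant"}}]}',  # role-only, no text
    '{"choices": [{"delta": {"content": "Hello"}}]}',
    '{"choices": [{"delta": {"content": " world"}}]}',
]

generated_text = ""
for chunk in chunks:
    delta = json.loads(chunk)["choices"][0]["delta"]
    # Falsy check skips both a missing key and an empty string.
    if delta.get("content", None):
        generated_text += delta["content"]

print(generated_text)  # Hello world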
90 changes: 90 additions & 0 deletions cmake/cpu_extension.cmake
@@ -0,0 +1,90 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

#
# Define environment variables for special configurations
#
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512BF16 ON)
endif()

include_directories("${CMAKE_SOURCE_DIR}/csrc")

#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")

execute_process(COMMAND cat /proc/cpuinfo
RESULT_VARIABLE CPUINFO_RET
OUTPUT_VARIABLE CPUINFO)

if (NOT CPUINFO_RET EQUAL 0)
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
endif()

function (find_isa CPUINFO TARGET OUT)
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
if(NOT ISA_FOUND EQUAL -1)
set(${OUT} ON PARENT_SCOPE)
else()
set(${OUT} OFF PARENT_SCOPE)
endif()
endfunction()

find_isa(${CPUINFO} "avx512f" AVX512_FOUND)

if (AVX512_FOUND)
list(APPEND CXX_COMPILE_FLAGS
"-mavx512f"
"-mavx512vl"
"-mavx512bw"
"-mavx512dq")

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
else()
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")


#
# Define extension targets
#

#
# _C extension
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp")

define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
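
The find_isa helper above is a plain substring search over /proc/cpuinfo, so it only works on Linux hosts. An equivalent check in Python — a sketch mirroring the CMake logic, not part of this commit:

def find_isa(target: str) -> bool:
    # Mirrors cmake/cpu_extension.cmake: substring search in /proc/cpuinfo.
    with open("/proc/cpuinfo") as f:
        return target in f.read()

if not find_isa("avx512f"):
    raise RuntimeError("vLLM CPU backend requires AVX512 ISA support.")
if not find_isa("avx512_bf16"):
    print("No avx512_bf16 CPU flag; set VLLM_CPU_AVX512BF16=1 to force it.")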

148 changes: 148 additions & 0 deletions csrc/cpu/activation.cpp
@@ -0,0 +1,148 @@
#include "cpu_types.hpp"

namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
scalar_t *__restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

TORCH_CHECK(d % VEC_ELEM_NUM == 0);

#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
int start = i * d;
if constexpr (is_gated) {
start *= 2;
}

const scalar_vec_t x(input + start + j);
const vec_op::FP32Vec8 f32_x(x);
vec_op::FP32Vec8 f32_ans = func(f32_x);

if constexpr (is_gated) {
const scalar_vec_t y(input + start + d + j);
const vec_op::FP32Vec8 f32_y(y);
f32_ans = f32_y * f32_ans;
}

const scalar_vec_t result(f32_ans);
result.save(output + i * d + j);
}
}
}

FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 x3 = x * x * x;
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
const vec_op::FP32Vec8 w3(0.044715);
const vec_op::FP32Vec8 x_3 = x * x * x;
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace

void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}

void gelu_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}

void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
activation_kernel<scalar_t, gelu_tanh_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
});
}

void gelu_new(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_new_impl)
activation_kernel<scalar_t, gelu_new_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
});
}

void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
activation_kernel<scalar_t, gelu_fast_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}
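
All of the gated kernels above (is_gated == true) split the last dimension of a [..., 2*d] input into a first half x and a second half y, then compute y * act(x); the non-gated variants apply act elementwise. A NumPy sketch of silu_and_mul's semantics, as a reference model rather than the production path:

import numpy as np

def silu(x):
    # Matches silu_act above: x / (1 + exp(-x)).
    return x / (1.0 + np.exp(-x))

def silu_and_mul(inp):
    # inp: [..., 2*d] -> out: [..., d], as in the gated activation_kernel.
    d = inp.shape[-1] // 2
    return inp[..., d:] * silu(inp[..., :d])

out = silu_and_mul(np.random.randn(4, 16).astype(np.float32))
print(out.shape)  # (4, 8)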
