Skip to content

Commit

Permalink
[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored May 10, 2024
1 parent 379da6d commit c833101
Show file tree
Hide file tree
Showing 17 changed files with 843 additions and 558 deletions.
2 changes: 1 addition & 1 deletion .buildkite/check-wheel-size.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import zipfile

MAX_SIZE_MB = 100
MAX_SIZE_MB = 150


def print_top_10_largest_files(zip_file):
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
Expand Down
4 changes: 2 additions & 2 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags")

if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
Expand All @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
Expand Down
286 changes: 110 additions & 176 deletions csrc/attention/attention_kernels.cu

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions csrc/attention/dtype_fp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif
#endif // USE_ROCM
#endif // ENABLE_FP8

namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache

enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
Expand All @@ -30,6 +37,5 @@ template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
4 changes: 3 additions & 1 deletion csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ void reshape_and_cache_flash(

// Just for unittest
void convert_fp8(
torch::Tensor& dst_cache,
torch::Tensor& src_cache,
torch::Tensor& dst_cache);
const float scale,
const std::string& kv_cache_dtype);
143 changes: 70 additions & 73 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

#include "cuda_compat.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
Expand Down Expand Up @@ -149,7 +150,7 @@ void copy_blocks(

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
Expand Down Expand Up @@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel(
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
}
}
Expand Down Expand Up @@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel(
}
} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);

void reshape_and_cache(
Expand All @@ -285,25 +282,8 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, float, false);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}

DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
}

void reshape_and_cache_flash(
Expand Down Expand Up @@ -353,35 +333,34 @@ void reshape_and_cache_flash(

namespace vllm {

template<typename Tout, typename Tin>
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const float kv_scale,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
}
}

} // namespace vllm

#define CALL_CONVERT_FP8(Tout, Tin) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
kv_scale, \
block_stride);

// Only for testing.
void convert_fp8(
torch::Tensor& dst_cache,
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
const float kv_scale,
const std::string& kv_cache_dtype)
{
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
Expand All @@ -399,17 +378,35 @@ void convert_fp8(
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

namespace vllm
{
namespace fp8_e4m3 {
#ifdef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
Expand Down Expand Up @@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
#endif // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x);
}
#endif
assert(false);
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
#endif
assert(false);
}

// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}

} // fp8
#endif // USE_ROCM
} // namespace vllm
File renamed without changes.
Loading

0 comments on commit c833101

Please sign in to comment.