From 3b8c08ff3d1435594fb585baf99a261934ddcb50 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 24 Jul 2024 11:36:52 -0700 Subject: [PATCH] Add fp8 support to `reshape_and_cache_flash` (#6667) Signed-off-by: Alvant --- csrc/cache.h | 3 +- csrc/cache_kernels.cu | 75 ++++++++++++++++----------- csrc/torch_bindings.cpp | 3 +- tests/kernels/test_cache.py | 42 ++++++++++++--- vllm/_custom_ops.py | 5 +- vllm/attention/backends/flash_attn.py | 2 + vllm/attention/backends/flashinfer.py | 2 + vllm/utils.py | 9 +++- 8 files changed, 98 insertions(+), 43 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 52177e8901a89..11c4c5001daaa 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index caef7f5e18630..1be806bbfa43c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel( } } -template +template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, // head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, // head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size) { + const int num_heads, const int head_size, const int block_size, + const float k_scale, const float v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; - k_cache[tgt_value_idx] = key[src_key_idx]; - v_cache[tgt_value_idx] = value[src_value_idx]; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } } } } // namespace vllm @@ -278,40 +288,45 @@ void reshape_and_cache( CALL_RESHAPE_AND_CACHE) } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, key_stride, \ + value_stride, num_heads, head_size, block_size, k_scale, v_scale); + void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) { - // FIXME: only support auto datatype, does not support fp8 - if (kv_cache_dtype != "auto") { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); - int block_size = k_cache.size(1); + int block_size = key_cache.size(1); int key_stride = key.stride(0); int value_stride = value.stride(0); - int block_stride = k_cache.stride(0); - TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + int block_stride = key_cache.stride(0); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), "reshape_and_cache_flash", [&] { - vllm::reshape_and_cache_flash_kernel - <<>>( - key.data_ptr(), value.data_ptr(), - k_cache.data_ptr(), v_cache.data_ptr(), - slot_mapping.data_ptr(), block_stride, key_stride, - value_stride, num_heads, head_size, block_size); - }); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH); } namespace vllm { diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0df9bdb75018f..3027b63ba2b33 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache," " Tensor! value_cache," " Tensor slot_mapping," - " str kv_cache_dtype) -> ()"); + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 70ae3d0c6e0c3..f9a609464abfc 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -215,8 +215,6 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - if kv_cache_dtype == "fp8": - pytest.skip() random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -248,15 +246,33 @@ def test_reshape_and_cache_flash( dtype, device=device, ) - key_cache, value_cache = key_caches[0], value_caches[0] + key_cache, value_cache = key_caches[0].contiguous( + ), value_caches[0].contiguous() + del key_caches + del value_caches # Clone the KV caches. - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() + if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Using default kv_scale + k_scale = v_scale = 1.0 # Call the reshape_and_cache kernel. ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + slot_mapping, kv_cache_dtype, k_scale, v_scale) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8(result_key_cache, key_cache) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") @@ -269,8 +285,18 @@ def test_reshape_and_cache_flash( cloned_key_cache[block_idx, block_offset, :, :] = key[i] cloned_value_cache[block_idx, block_offset, :, :] = value[i] - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) + if kv_cache_dtype == "fp8": + assert torch.allclose(result_key_cache, + cloned_key_cache, + atol=0.001, + rtol=0.1) + assert torch.allclose(result_value_cache, + cloned_value_cache, + atol=0.001, + rtol=0.1) + else: + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) @pytest.mark.parametrize("direction", COPYING_DIRECTION) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e5151c070f2f7..0186594656cc1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -426,10 +426,13 @@ def reshape_and_cache_flash( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, + k_scale: float, + v_scale: float, ) -> None: torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype) + kv_cache_dtype, k_scale, + v_scale) def copy_blocks(key_caches: List[torch.Tensor], diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index b16a204c8f44e..949bd973cf3c4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -478,6 +478,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, + k_scale, + v_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 9dac12d3b906d..2a4900489df35 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -489,6 +489,8 @@ def forward( kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, + k_scale, + v_scale, ) query = query.contiguous( diff --git a/vllm/utils.py b/vllm/utils.py index 83605631b5bd6..876c3bf90b02c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - assert cache_dtype != "fp8" torch.random.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) @@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash( key_value_cache = torch.empty(size=key_value_cache_shape, dtype=torch_dtype, device=device) - key_value_cache.uniform_(-scale, scale) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") key_caches.append(key_value_cache[:, 0]) value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches