Prefix Caching with FP8 KV cache support #3234

Closed
wants to merge 11 commits
csrc/dispatch_utils.h (2 additions, 1 deletion)
@@ -19,12 +19,13 @@
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)\
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
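The new Float8_e5m2 case sits alongside Byte presumably because, depending on the PyTorch build, an FP8 KV cache may surface either as a native float8 tensor or as raw uint8 storage, and kernels launched through VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES have to compile for both. A minimal sketch of that dtype resolution, using a hypothetical helper name rather than vLLM's actual code:

```python
import torch


def resolve_kv_cache_dtype(kv_cache_dtype: str,
                           model_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical helper: map the kv_cache_dtype string to a storage dtype.

    Illustrative only; vLLM's own mapping lives elsewhere and may differ.
    """
    if kv_cache_dtype == "auto":
        # "auto" reuses the model's activation dtype for the cache.
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        # Use the native float8 dtype when this torch build exposes it,
        # otherwise fall back to raw bytes; the dispatch macro above covers
        # both the Float8_e5m2 and the Byte case.
        return getattr(torch, "float8_e5m2", torch.uint8)
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype!r}")
```

For example, `resolve_kv_cache_dtype("fp8_e5m2", torch.float16)` would return `torch.float8_e5m2` on a torch build that has it and `torch.uint8` otherwise.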
requirements.txt (1 addition, 1 deletion)
@@ -11,6 +11,6 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
-triton >= 2.1.0
+triton >= 2.2.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
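The Triton floor moves from 2.1.0 to 2.2.0, presumably for the prefix-attention Triton kernel used with the FP8 cache. A quick sanity check of the installed version (illustrative, not part of the PR; assumes the `packaging` package is available):

```python
# Confirm the local Triton install satisfies the bumped pin before running
# the prefix-caching kernels on this branch.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("triton"))
assert installed >= Version("2.2.0"), (
    f"triton {installed} is too old; this branch requires triton >= 2.2.0")
```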
tests/kernels/test_cache.py (8 additions, 0 deletions)
@@ -92,9 +92,17 @@ def test_copy_blocks(

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        # NOTE: torch.allclose does not support the
        # torch.float8_e5m2 / torch.float8_e4m3fn dtypes.
        if kv_cache_dtype == "fp8_e5m2":
            key_cache = key_cache.view(torch.half)
            cloned_key_cache = cloned_key_cache.view(torch.half)
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        if kv_cache_dtype == "fp8_e5m2":
            value_cache = value_cache.view(torch.half)
            cloned_value_cache = cloned_value_cache.view(torch.half)
        assert torch.allclose(value_cache, cloned_value_cache)


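The test reinterprets the FP8 cache as torch.half purely so torch.allclose can run; since the copied blocks should be bit-identical, an equivalent check is to compare the raw byte view directly. A small sketch of that idea, as a hypothetical helper that is not part of the test suite:

```python
import torch


def assert_kv_caches_equal(actual: torch.Tensor,
                           expected: torch.Tensor,
                           kv_cache_dtype: str) -> None:
    """Hypothetical helper mirroring the comparison added above."""
    if kv_cache_dtype == "fp8_e5m2":
        # torch.allclose cannot handle float8 dtypes here, so compare the
        # underlying bit patterns; copied cache blocks must match exactly.
        assert torch.equal(actual.view(torch.uint8),
                           expected.view(torch.uint8))
    else:
        assert torch.allclose(actual, expected)
```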