From e486f525155e3dc92e84c9c04abc3ac5ad77bde4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 1 Sep 2023 14:27:21 +0000 Subject: [PATCH 1/2] Fix dispatch logic --- csrc/activation_kernels.cu | 10 ++++------ csrc/cache_kernels.cu | 14 +++++--------- csrc/dispatch_utils.h | 14 ++++++++++++++ csrc/layernorm_kernels.cu | 5 ++--- csrc/pos_encoding_kernels.cu | 6 +++--- 5 files changed, 28 insertions(+), 21 deletions(-) create mode 100644 csrc/dispatch_utils.h diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index fc1f086f502d3..c6ae5db8f9c48 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,8 @@ #include #include +#include "dispatch_utils.h" + namespace vllm { template @@ -34,9 +36,7 @@ void silu_and_mul( dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "silu_and_mul_kernel", [&] { @@ -71,9 +71,7 @@ __global__ void activation_kernel( dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - AT_DISPATCH_FLOATING_TYPES_AND2( \ - at::ScalarType::Half, \ - at::ScalarType::BFloat16, \ + VLLM_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), \ "activation_kernel", \ [&] { \ diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 5e7b6be420ffa..ddad2b5a29b9e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,8 @@ #include #include +#include "dispatch_utils.h" + #include #include #include @@ -125,9 +127,7 @@ void copy_blocks( dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -202,9 +202,7 @@ void reshape_and_cache( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( key.scalar_type(), "reshape_and_cache_kernel", [&] { @@ -364,9 +362,7 @@ void gather_cached_kv( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( key.scalar_type(), "gather_cached_kv_kernel_optimized", [&] { diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h new file mode 100644 index 0000000000000..ca30da8879382 --- /dev/null +++ b/csrc/dispatch_utils.h @@ -0,0 +1,14 @@ +/* + * Adapted from + * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h + */ +#include + +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 73503c55840ca..f932b9e2d6150 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "dispatch_utils.h" #include "reduction_utils.cuh" namespace vllm { @@ -46,9 +47,7 @@ void rms_norm( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_kernel", [&] { diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 98939fc719a01..ced26ecb36979 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,8 @@ #include #include +#include "dispatch_utils.h" + namespace vllm { template @@ -83,9 +85,7 @@ void rotary_embedding_neox( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, + VLLM_DISPATCH_FLOATING_TYPES( query.scalar_type(), "rotary_embedding_neox", [&] { From 46d9982399348cf4c1c4f9d33011ec370d8865cf Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 1 Sep 2023 14:36:21 +0000 Subject: [PATCH 2/2] Minor --- csrc/dispatch_utils.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index ca30da8879382..7c0c49d392a98 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,11 +4,11 @@ */ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))