Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid compiling kernels for double data type #933

Merged
merged 2 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

namespace vllm {

template<typename T>
Expand Down Expand Up @@ -34,9 +36,7 @@ void silu_and_mul(
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"silu_and_mul_kernel",
[&] {
Expand Down Expand Up @@ -71,9 +71,7 @@ __global__ void activation_kernel(
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::Half, \
at::ScalarType::BFloat16, \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
Expand Down
14 changes: 5 additions & 9 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

#include <algorithm>
#include <cassert>
#include <map>
Expand Down Expand Up @@ -125,9 +127,7 @@ void copy_blocks(
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
Expand Down Expand Up @@ -202,9 +202,7 @@ void reshape_and_cache(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
Expand Down Expand Up @@ -364,9 +362,7 @@ void gather_cached_kv(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
Expand Down
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#include <torch/extension.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
5 changes: 2 additions & 3 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"

namespace vllm {
Expand Down Expand Up @@ -46,9 +47,7 @@ void rms_norm(
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"rms_norm_kernel",
[&] {
Expand Down
6 changes: 3 additions & 3 deletions csrc/pos_encoding_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

namespace vllm {

template<typename scalar_t>
Expand Down Expand Up @@ -83,9 +85,7 @@ void rotary_embedding_neox(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding_neox",
[&] {
Expand Down
Loading