Enable CUDA graph for GPTQ & SqueezeLLM (#2318)
WoosukKwon authored Jan 3, 2024
1 parent 9140561 commit 6ef00b0
Showing 3 changed files with 15 additions and 13 deletions.
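
Background for the diffs below: CUDA graph capture (torch.cuda.CUDAGraph) records only the work submitted to the stream under capture. A launch written as kernel<<<grid, block>>> targets CUDA's legacy default stream, which escapes capture and, in global capture mode, invalidates it; this is why vLLM previously forced eager mode for GPTQ and SqueezeLLM. Every hunk in this commit applies the same two-line fix: query PyTorch's current stream and pass it as the fourth launch parameter. A minimal sketch of the pattern as a Torch extension op (scale_kernel and launch_scale are hypothetical names, not part of this commit):

    #include <ATen/cuda/CUDAContext.h>
    #include <torch/extension.h>

    __global__ void scale_kernel(float* x, float s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] *= s;
    }

    void launch_scale(torch::Tensor x, float s) {
      const int n = x.numel();
      // scale_kernel<<<grid, block>>>(...) would enqueue on the legacy
      // default stream and escape (or break) CUDA graph capture.
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(
          x.data_ptr<float>(), s, n);
    }
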
18 changes: 12 additions & 6 deletions csrc/quantization/gptq/q_gemm.cu
@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
 
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
 
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
 
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
 
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
         gridDim.x = DIVIDE(width, THREADS_X);
         gridDim.y = height / 8;
 
-        make_sequential_kernel<<<gridDim, blockDim>>>
+        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+        make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             q_weight,
             new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 
 } // namespace gptq
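
All six launches above get the same treatment. The underlying mechanism, for reference: capture is per-stream. Between cudaStreamBeginCapture and cudaStreamEndCapture, launches on the capturing stream are recorded into a graph instead of executed, and the instantiated graph can later be replayed with a single cheap cudaGraphLaunch. A standalone raw-CUDA sketch of that mechanism (add_one is hypothetical; error checking omitted):

    #include <cuda_runtime.h>

    __global__ void add_one(float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] += 1.0f;
    }

    int main() {
      const int n = 1 << 20;
      float* d;
      cudaMalloc(&d, n * sizeof(float));
      cudaMemset(d, 0, n * sizeof(float));

      cudaStream_t stream;
      cudaStreamCreate(&stream);

      // Record instead of run. A <<<grid, block>>> launch on the legacy
      // default stream here would invalidate the capture, which is the
      // failure this commit fixes in the GPTQ/SqueezeLLM kernels.
      cudaGraph_t graph;
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      add_one<<<(n + 255) / 256, 256, 0, stream>>>(d, n);
      cudaStreamEndCapture(stream, &graph);

      cudaGraphExec_t exec;
      cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);  // pre-CUDA-12 signature

      cudaGraphLaunch(exec, stream);   // replay the recorded work
      cudaStreamSynchronize(stream);

      cudaGraphExecDestroy(exec);
      cudaGraphDestroy(graph);
      cudaStreamDestroy(stream);
      cudaFree(d);
      return 0;
    }
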
4 changes: 3 additions & 1 deletion csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -200,8 +200,10 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
 
-  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else
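
Besides the stream change, this hunk adds an OptionalCUDAGuard, which makes the device that owns vec current for the duration of the call, so both getCurrentCUDAStream() and the launch target the right GPU in multi-device setups. The usual prologue for a PyTorch CUDA extension op, sketched with a hypothetical zero_ op (not part of this commit):

    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <torch/extension.h>

    __global__ void zero_kernel(float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] = 0.0f;
    }

    void zero_(torch::Tensor vec) {
      // Switch to vec's device for this scope (no-op if already current).
      const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
      // "Current stream" now means the stream for vec's device.
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      const int n = vec.numel();
      zero_kernel<<<(n + 255) / 256, 256, 0, stream>>>(
          vec.data_ptr<float>(), n);
    }
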
6 changes: 0 additions & 6 deletions vllm/config.py
@@ -181,12 +181,6 @@ def _verify_cuda_graph(self) -> None:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True
 
     def verify_with_parallel_config(
         self,
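
With this check removed, GPTQ and SqueezeLLM models take the same CUDA-graph capture path as unquantized models; eager execution becomes opt-in via enforce_eager rather than being forced for these quantization methods.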
