AWQ: Up to 2.66x higher throughput (vllm-project#2566)
casper-hansen authored and NikolaBorisov committed Jan 31, 2024
1 parent 2d9f410 commit 1516c9a
Showing 4 changed files with 127 additions and 1 deletion.
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy);
#endif

void squeezellm_gemm(
1 change: 1 addition & 0 deletions csrc/pybind.cpp
@@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
  // Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
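For context, the new binding is driven from Python exactly like the existing awq_gemm op. Below is a minimal sketch of a standalone call; the vllm._C import path, the example sizes, and the random packed tensors are assumptions for illustration, not part of this diff:

import torch
from vllm._C import ops  # assumed import path for the compiled extension

in_c, out_c, group = 4096, 4096, 128  # example AWQ layer sizes (assumption)
qweight = torch.randint(0, 2**31 - 1, (in_c, out_c // 8),
                        dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (in_c // group, out_c // 8),
                       dtype=torch.int32, device="cuda")
scales = torch.rand(in_c // group, out_c, dtype=torch.float16, device="cuda")

# Zeros for split_k_iters/thx/thy let the op pick its default launch geometry.
w_fp16 = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
assert w_fp16.shape == (in_c, out_c) and w_fp16.dtype == torch.float16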
108 changes: 108 additions & 0 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
#endif
}

__global__ void __launch_bounds__(64) dequantize_weights(
    int* __restrict__ B,
    half* __restrict__ scaling_factors,
    int* __restrict__ zeros,
    half* __restrict__ C,
    int G
)
{
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
  static constexpr uint32_t ZERO = 0x0;
  half B_shared[32 * (128 + 8)];

  half* B_shared_ptr2 = B_shared;

  half B_shared_warp[32];
  int OC = 512;

  int N = blockDim.x * gridDim.x;  // 2
  int col = (blockIdx.x * blockDim.x + threadIdx.x);
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int index1 = 8 * col + 8 * row * N;
  half* C_ptr2 = C + index1;

  int index2 = col + row * N;
  int* B_ptr2 = B + index2;

  int index3 = col + (int)(row / G) * N;
  int* zeros_ptr2 = zeros + index3;
  int index4 = 8 * col + (int)(row / G) * N * 8;
  half* scaling_factors_ptr2 = scaling_factors + index4;

  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
  int j = 0;

  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;

  for (int i = 0; i < 8; ++i) {
    *(C_ptr2 + i) = B_shared[i];
  }
}

} // namespace awq
} // namespace vllm
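Conceptually, each thread of dequantize_weights loads one packed int32 of weights (8 four-bit values), the matching packed zero-points, and 8 fp16 scales, then writes 8 fp16 outputs computed as (q - z) * s via the f16x2 PTX above. A rough scalar reference in Python follows; it is a sketch only: the helper name is hypothetical and it ignores the interleaved nibble order that dequantize_s4_to_fp16x2 actually unpacks.

import numpy as np

def dequant_one_word(packed_w: int, packed_z: int, scales_8: np.ndarray) -> np.ndarray:
    """Reference for one thread: 8 four-bit weights from one packed int32."""
    # Simplified little-endian nibble order; the CUDA path uses AWQ's
    # interleaved order inside dequantize_s4_to_fp16x2.
    q = np.array([(packed_w >> (4 * i)) & 0xF for i in range(8)], dtype=np.float32)
    z = np.array([(packed_z >> (4 * i)) & 0xF for i in range(8)], dtype=np.float32)
    return ((q - z) * scales_8.astype(np.float32)).astype(np.float16)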

torch::Tensor awq_dequantize(
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters,
    int thx,
    int thy)
{
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
  int G = in_c / _scaling_factors.size(0);

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
      kernel, scaling_factors, zeros, de_kernel, G);

  return _de_kernel;
}
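When thx and thy are both zero (the way apply_weights calls the op below), the launch covers the packed matrix with 8x8 thread blocks, one thread per packed int32. A small arithmetic sketch of that default geometry, using an assumed 4096x4096 layer purely for illustration:

# Example only: a 4096x4096 fp16 weight stored as AWQ int4.
in_c, out_c = 4096, 4096
qout_c = out_c // 8                      # 512 packed int32 columns

threads_per_block = (8, 8)               # the thx == 0 and thy == 0 branch above
num_blocks = (qout_c // 8, in_c // 8)    # (64, 512)

# 64*8 = 512 threads across packed columns and 512*8 = 4096 across rows; each
# thread writes 8 fp16 values, covering all 4096*4096 outputs exactly once.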

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
11 changes: 10 additions & 1 deletion vllm/model_executor/layers/quantization/awq.py
@@ -153,7 +153,16 @@ def apply_weights(self,
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
+
+        # num_tokens >= threshold
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
+
+        if FP16_MATMUL_HEURISTIC_CONDITION:
+            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+            out = torch.matmul(reshaped_x, out)
+        else:
+            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                               pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)
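The heuristic routes prefill-sized batches (256 or more tokens) through a plain fp16 matmul on the dequantized weight, where the dense GEMM outperforms the fused int4 kernel, and keeps the fused awq_gemm path for small decode batches. A minimal sketch of how the crossover could be sanity-checked on a given GPU; the timing helper and the tensor shapes it expects are assumptions, not part of this change:

import torch
from vllm._C import ops  # assumed import path for the compiled extension

def time_ms(fn, iters: int = 50) -> float:
    """Average CUDA time of fn() in milliseconds."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()                       # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def compare_paths(x, qweight, scales, qzeros, pack_factor=8):
    """x: fp16 [num_tokens, in_c]; qweight/scales/qzeros as held by AWQLinearMethod."""
    fused = time_ms(lambda: ops.awq_gemm(x, qweight, scales, qzeros, pack_factor))
    dq = time_ms(lambda: torch.matmul(
        x, ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)))
    print(f"num_tokens={x.shape[0]}: awq_gemm {fused:.3f} ms, "
          f"dequantize+matmul {dq:.3f} ms")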
