Support SqueezeLLM (vllm-project#1326)
Co-authored-by: squeeze-ai-lab <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
3 people authored and skrider committed Oct 27, 2023
1 parent 7f7b152 commit 9dd265d
Showing 16 changed files with 378 additions and 40 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark_latency.py
@@ -70,7 +70,7 @@ def run_to_completion(profile: bool = False):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -201,7 +201,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
12 changes: 8 additions & 4 deletions csrc/quantization.cpp
@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
torch::Tensor _zeros,
int split_k_iters);

void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
148 changes: 148 additions & 0 deletions csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -0,0 +1,148 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace vllm {
namespace squeezellm {

__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
const half2* __restrict__ vec,
const int* __restrict__ mat,
half2* __restrict__ mul,
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {

const int blockwidth2 = BLOCKWIDTH / 2;

int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

__shared__ half2 blockvec[blockwidth2];

__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}

__half res;
half2 res2;
half2 tmp2;

int i;
int k;

unsigned int tmp1;
unsigned int lut_index1, lut_index2;

for (int b = 0; b < batch; ++b){
i = width * row + col;
res = __int2half_rd(0);
k = 0;

__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
__syncthreads();

while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);

res2 = {};
tmp2 = {};

lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 0], res2);

lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 1], res2);

lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 2], res2);

lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
res2 = __hfma2(tmp2, blockvec[k + 3], res2);

res = __hadd(__hadd(res2.x, res2.y), res);

i += width;
k += 4;
}

// col%2 -> only set one of the two values
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}

atomicAdd(&mul[b * width / 2 + col / 2], res3);
}
}

} // namespace squeezellm
} // namespace vllm

// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0);
int width = mat.size(1);

int batch = vec.size(0);
int vec_height = vec.size(1);

dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);

vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
(half2*) vec.data<at::Half>(),
mat.data_ptr<int>(),
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
height, width, batch, vec_height
);
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
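
For reference, the following pure-PyTorch sketch spells out what NUQ4MatMulKernel computes. The shapes and nibble order are inferred from the kernel's indexing above (each int32 in mat packs eight 4-bit codes along the input dimension, lowest nibble first, and lookup_table holds 16 fp16 centroids per output channel); the helper name and signature are illustrative, not part of the commit.

```python
import torch


def squeezellm_matvec_reference(vec: torch.Tensor,
                                qweight: torch.Tensor,
                                lookup_table: torch.Tensor) -> torch.Tensor:
    """Sketch of the 4-bit LUT matvec (shapes assumed from the kernel):

      vec:          [batch, in_features]              float16 activations
      qweight:      [in_features // 8, out_features]  int32, eight 4-bit codes per word
      lookup_table: [out_features, 16]                float16, one table per output channel
    """
    in_features = vec.shape[1]
    out_features = qweight.shape[1]

    # Unpack the eight 4-bit codes of every int32 word, lowest nibble first,
    # mirroring the (tmp1 >> 4*j) & 0xF pattern in the kernel.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device)
    codes = (qweight.unsqueeze(-1) >> shifts) & 0xF                     # [K // 8, N, 8]
    codes = codes.permute(0, 2, 1).reshape(in_features, out_features).long()

    # Dequantize: output channel n looks its codes up in lookup_table[n].
    dequant = torch.gather(lookup_table.t(), 0, codes)                  # [K, N]

    # The kernel accumulates in fp16 and atomicAdds into the output; fp32 is
    # used here only to keep the reference numerically tame.
    return (vec.float() @ dequant.float()).to(torch.float16)
```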
1 change: 1 addition & 0 deletions setup.py
@@ -200,6 +200,7 @@ def get_torch_arch_list() -> Set[str]:
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -103,7 +103,7 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode

def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
supported_quantization = ["awq", "squeezellm"]
if self.quantization is None:
return
quantization = self.quantization.lower()
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -168,7 +168,7 @@ def add_cli_args(
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
choices=['awq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')
return parser
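
With the engine flag in place, a SqueezeLLM checkpoint is selected the same way an AWQ one is. A minimal offline-inference sketch, assuming the LLM entrypoint forwards quantization to the engine just as this CLI flag does (the model path is a placeholder, not a real checkpoint):

```python
from vllm import LLM, SamplingParams

# Placeholder path: any SqueezeLLM-quantized checkpoint in the expected format.
llm = LLM(model="path/to/squeezellm-quantized-model", quantization="squeezellm")

sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```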
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,10 +1,14 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.layers.quantized_linear.squeezellm import (
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)

_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
"squeezellm":
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
}


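
The registry maps the --quantization value onto the column- and row-parallel linear implementations. A minimal sketch of how a caller could resolve the pair; the resolve_parallel_linear_classes helper is illustrative only, and the real consumer lives elsewhere in vLLM, outside this diff:

```python
from vllm.model_executor.layers.quantized_linear import _QUANTIZED_LINEAR_REGISTRY
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                       RowParallelLinear)


def resolve_parallel_linear_classes(quant_method):
    """Pick the (column, row) parallel linear classes for a quantization method.

    `quant_method` is None for unquantized models, otherwise a registered key
    such as "awq" or, with this commit, "squeezellm".
    """
    if quant_method is None:
        return ColumnParallelLinear, RowParallelLinear
    return _QUANTIZED_LINEAR_REGISTRY[quant_method]


col_cls, row_cls = resolve_parallel_linear_classes("squeezellm")
```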
14 changes: 9 additions & 5 deletions vllm/model_executor/layers/quantized_linear/awq.py
@@ -11,9 +11,11 @@
class AWQColumnParallelLinear(ColumnParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
assert self.input_size % self.quant_config.group_size == 0
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size,
@@ -62,9 +64,11 @@ def apply_weights(
class AWQRowParallelLinear(RowParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
if self.input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
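
The replaced asserts now surface as an explicit ValueError whenever a tensor-parallel partition does not contain a whole number of packed words (pack_factor = 32 // weight_bits = 8 for 4-bit weights). A small arithmetic sketch with illustrative sizes; check_partition_alignment is a stand-in, not the vLLM function:

```python
def check_partition_alignment(size_per_partition: int, pack_factor: int = 8) -> None:
    # Mirrors the new error path: eight 4-bit codes share one int32 word, so a
    # partition must hold a multiple of pack_factor features.
    if size_per_partition % pack_factor != 0:
        raise ValueError(
            "The tensor parallel size is not aligned with the quantized "
            "weight shape. Please use a different tensor parallel size.")


check_partition_alignment(11008 // 4)    # 2752 features per partition: OK
# check_partition_alignment(11008 // 7)  # 1572 % 8 == 4, so this would raise
```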
84 changes: 84 additions & 0 deletions vllm/model_executor/layers/quantized_linear/squeezellm.py
@@ -0,0 +1,84 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)


class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size // self.quant_config.pack_factor,
self.output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size_per_partition,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)

def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)

if bias is not None:
out = out + bias
return out.reshape(out_shape)


class SqueezeLLMRowParallelLinear(RowParallelLinear):

def create_weights(self, dtype: torch.dtype) -> None:
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The tensor parallel size is not aligned with the quantized "
"weight shape. Please use a different tensor parallel size.")
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.pack_factor,
self.output_size,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.lookup_table = Parameter(
torch.empty(
self.output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)

def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
self.lookup_table)
return out.reshape(out_shape)
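
Both apply_weights paths allocate the output with torch.zeros rather than torch.empty because the CUDA kernel accumulates partial sums into it with atomicAdd instead of overwriting it. A minimal call sketch against the bound op, using randomly filled stand-in tensors (requires the compiled extension and a CUDA device; shapes follow create_weights above):

```python
import torch

from vllm import quantization_ops

batch, in_features, out_features = 4, 4096, 4096
pack_factor = 8  # eight 4-bit codes per int32 word

x = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")
qweight = torch.randint(torch.iinfo(torch.int32).min,
                        torch.iinfo(torch.int32).max,
                        (in_features // pack_factor, out_features),
                        dtype=torch.int32, device="cuda")
lookup_table = torch.randn(out_features, 16, dtype=torch.float16, device="cuda")

# The kernel adds into `out`, so it must start at zero; torch.empty would leave
# garbage that gets folded into the result.
out = torch.zeros(batch, out_features, dtype=torch.float16, device="cuda")
quantization_ops.squeezellm_gemm(x, qweight, out, lookup_table)
```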
24 changes: 15 additions & 9 deletions vllm/model_executor/models/llama.py
@@ -313,17 +313,21 @@ def load_weights(self,
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
col_weight_suffixes = ["weight"]
row_weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
col_weight_suffixes = (
self.quant_config.get_col_parallel_tensor_names())
row_weight_suffixes = (
self.quant_config.get_row_parallel_tensor_names())

column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
for suffix in col_weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
for suffix in row_weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")

tp_size = get_tensor_model_parallel_world_size()
@@ -350,10 +354,10 @@ def load_weights(self,
if "rotary_emb.inv_freq" in name:
continue

is_packed = False
packed_dim = None
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
packed_dim = self.quant_config.get_packed_dim(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -367,9 +371,11 @@ def load_weights(self,
if is_transposed:
param = param.T

if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
if packed_dim is not None:
shard_dim = 0 if not is_transposed else 1
if packed_dim == shard_dim:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor

if weight_name in ["k_proj", "v_proj"]:
shard_id = tp_rank // num_kv_heads_replicas
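
The loader now rescales the shard geometry only when the packed dimension is the one being sharded: SqueezeLLM stores qweight as [input_size // pack_factor, output_size] (see create_weights above), so only shards taken along the packed input dimension shrink by pack_factor. A hypothetical helper sketching the arithmetic (not part of the diff):

```python
def shard_geometry(shard_size, offset, packed_dim, shard_dim, pack_factor=8):
    """Return (shard_size, offset) in units of the stored, packed tensor.

    Sketch of the new logic: divide by pack_factor only when the packed
    dimension coincides with the shard dimension; the old code divided
    unconditionally for any packed weight.
    """
    if packed_dim is not None and packed_dim == shard_dim:
        return shard_size // pack_factor, offset // pack_factor
    return shard_size, offset


# Sharding 2048 features along the packed input dimension (8 codes per word):
print(shard_geometry(2048, 2048, packed_dim=0, shard_dim=0))  # -> (256, 256)
# The same shard along the unpacked output dimension is left untouched:
print(shard_geometry(2048, 2048, packed_dim=0, shard_dim=1))  # -> (2048, 2048)
```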
