diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 5b946e191e453..e560cb1fbfc06 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -70,7 +70,7 @@ def run_to_completion(profile: bool = False):
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 13df1a5a0c874..fc578b4972863 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -201,7 +201,7 @@ def main(args: argparse.Namespace):
     parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp
index 3afa7f6a231d7..dfe17a496c780 100644
--- a/csrc/quantization.cpp
+++ b/csrc/quantization.cpp
@@ -7,9 +7,13 @@ torch::Tensor awq_gemm(
   torch::Tensor _zeros,
   int split_k_iters);
 
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "awq_gemm",
-    &awq_gemm,
-    "Quantized GEMM for AWQ");
+  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
 }
diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
new file mode 100644
index 0000000000000..1392b877397be
--- /dev/null
+++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -0,0 +1,148 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// half-tensor
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDATensorMethods.cuh>
+
+#define BLOCKWIDTH 128
+#define BLOCKHEIGHT4 16
+
+namespace vllm {
+namespace squeezellm {
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+// 4-bit matvec kernel (LUT-based)
+__global__ void NUQ4MatMulKernel(
+  const half2* __restrict__ vec,
+  const int* __restrict__ mat,
+  half2* __restrict__ mul,
+  const __half* __restrict__ lookup_table,
+  int height,
+  int width,
+  int batch,
+  int vec_height
+) {
+
+  const int blockwidth2 = BLOCKWIDTH / 2;
+
+  int row = BLOCKHEIGHT4 * blockIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ half2 blockvec[blockwidth2];
+
+  __shared__ __half deq2[16][BLOCKWIDTH];
+  int off = threadIdx.x;
+  int column_offset = col * 16;
+  for (int val = 0; val < 16; val += 1) {
+    int lut_index = column_offset + val;
+    deq2[val][off] = lookup_table[lut_index];
+  }
+
+  __half res;
+  half2 res2;
+  half2 tmp2;
+
+  int i;
+  int k;
+
+  unsigned int tmp1;
+  unsigned int lut_index1, lut_index2;
+
+  for (int b = 0; b < batch; ++b){
+    i = width * row + col;
+    res = __int2half_rd(0);
+    k = 0;
+
+    __syncthreads();
+    if (threadIdx.x < blockwidth2)
+      blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
+    __syncthreads();
+
+    while (k < blockwidth2) {
+      tmp1 = as_unsigned(mat[i]);
+
+      res2 = {};
+      tmp2 = {};
+
+      lut_index1 = tmp1 & 0xF;
+      lut_index2 = (tmp1 >> 4) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 0], res2);
+
+      lut_index1 = (tmp1 >> 8) & 0xF;
+      lut_index2 = (tmp1 >> 12) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 1], res2);
+
+      lut_index1 = (tmp1 >> 16) & 0xF;
+      lut_index2 = (tmp1 >> 20) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 2], res2);
+
+      lut_index1 = (tmp1 >> 24) & 0xF;
+      lut_index2 = (tmp1 >> 28) & 0xF;
+      tmp2.x = deq2[lut_index1][off];
+      tmp2.y = deq2[lut_index2][off];
+      res2 = __hfma2(tmp2, blockvec[k + 3], res2);
+
+      res = __hadd(__hadd(res2.x, res2.y), res);
+
+      i += width;
+      k += 4;
+    }
+
+    // col%2 -> only set one of the two values
+    half2 res3 = {};
+    if (col % 2 == 0) {
+      res3.x = res;
+    } else {
+      res3.y = res;
+    }
+
+    atomicAdd(&mul[b * width / 2 + col / 2], res3);
+  }
+}
+
+} // namespace squeezellm
+} // namespace vllm
+
+// 4-bit matvec kernel (LUT-based)
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+    (half2*) vec.data<at::Half>(),
+    mat.data_ptr<int>(),
+    (half2*) mul.data<at::Half>(),
+    (__half*) lookup_table.data<at::Half>(),
+    height, width, batch, vec_height
+  );
+}
+
+#undef BLOCKWIDTH
+#undef BLOCKHEIGHT4
diff --git a/setup.py b/setup.py
index 6ffc03c25386d..4bcd53394b4e5 100644
--- a/setup.py
+++ b/setup.py
@@ -200,6 +200,7 @@ def get_torch_arch_list() -> Set[str]:
     sources=[
         "csrc/quantization.cpp",
         "csrc/quantization/awq/gemm_kernels.cu",
+        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
diff --git a/vllm/config.py b/vllm/config.py
index d45bb8857ec35..6e19491083d44 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -103,7 +103,7 @@ def _verify_tokenizer_mode(self) -> None:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq"]
+        supported_quantization = ["awq", "squeezellm"]
         if self.quantization is None:
             return
         quantization = self.quantization.lower()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 51a8161bdbc13..cc425a2c079e7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -168,7 +168,7 @@ def add_cli_args(
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', None],
+                            choices=['awq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py
index eecfe8149ebf3..b09358261d5d1 100644
--- a/vllm/model_executor/layers/quantized_linear/__init__.py
+++ b/vllm/model_executor/layers/quantized_linear/__init__.py
@@ -1,10 +1,14 @@
 from vllm.model_executor.layers.quantized_linear.awq import (
     AWQColumnParallelLinear, AWQRowParallelLinear)
+from vllm.model_executor.layers.quantized_linear.squeezellm import (
+    SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
 from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
                                                        RowParallelLinear)
 
 _QUANTIZED_LINEAR_REGISTRY = {
     "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+    "squeezellm":
+    (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
 }
diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py
index 0d7d0f9116a80..31e341318d400 100644
--- a/vllm/model_executor/layers/quantized_linear/awq.py
+++ b/vllm/model_executor/layers/quantized_linear/awq.py
@@ -11,9 +11,11 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
 
     def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.input_size % self.quant_config.weight_bits == 0
-        assert (self.output_size_per_partition %
-                self.quant_config.pack_factor == 0)
+        assert self.input_size % self.quant_config.group_size == 0
+        if self.output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size,
@@ -62,9 +64,11 @@ def apply_weights(
 class AWQRowParallelLinear(RowParallelLinear):
 
     def create_weights(self, dtype: torch.dtype) -> None:
-        assert (self.input_size_per_partition %
-                self.quant_config.weight_bits == 0)
         assert self.output_size % self.quant_config.pack_factor == 0
+        if self.input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
         self.qweight = Parameter(
             torch.empty(
                 self.input_size_per_partition,
diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py
new file mode 100644
index 0000000000000..3ccbc4e579dc6
--- /dev/null
+++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import quantization_ops
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)
+
+
+class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.pack_factor,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size_per_partition,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+
+class SqueezeLLMRowParallelLinear(RowParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        if self.input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The tensor parallel size is not aligned with the quantized "
+                "weight shape. Please use a different tensor parallel size.")
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.pack_factor,
+                self.output_size,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.lookup_table = Parameter(
+            torch.empty(
+                self.output_size,
+                self.quant_config.weight_bits**2,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
+        # NOTE: The output tensor should be zero-initialized.
+        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+        quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
+                                         self.lookup_table)
+        return out.reshape(out_shape)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 15cdae46e9671..fb7569e6da7d5 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -313,17 +313,21 @@ def load_weights(self,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())
 
         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -350,10 +354,10 @@ def load_weights(self,
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -367,9 +371,11 @@ def load_weights(self,
                 if is_transposed:
                     param = param.T
 
-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor
 
                 if weight_name in ["k_proj", "v_proj"]:
                     shard_id = tp_rank // num_kv_heads_replicas
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index 94323dd923910..8660223831f14 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -301,17 +301,21 @@ def load_weights(self,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
         if self.quant_config is None:
-            weight_suffixes = ["weight"]
+            col_weight_suffixes = ["weight"]
+            row_weight_suffixes = ["weight"]
         else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
+            col_weight_suffixes = (
+                self.quant_config.get_col_parallel_tensor_names())
+            row_weight_suffixes = (
+                self.quant_config.get_row_parallel_tensor_names())
 
         column_parallel_weights: List[str] = []
         for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in col_weight_suffixes:
                 column_parallel_weights.append(f"{layer}.{suffix}")
         row_parallel_weights: List[str] = []
         for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
+            for suffix in row_weight_suffixes:
                 row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -334,10 +338,10 @@ def load_weights(self,
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            is_packed = False
+            packed_dim = None
             is_transposed = False
             if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
+                packed_dim = self.quant_config.get_packed_dim(name)
                 is_transposed = self.quant_config.is_transposed(name)
             if is_transposed:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -351,9 +355,11 @@ def load_weights(self,
                 if is_transposed:
                     param = param.T
 
-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
+                if packed_dim is not None:
+                    shard_dim = 0 if not is_transposed else 1
+                    if packed_dim == shard_dim:
+                        shard_size //= self.quant_config.pack_factor
+                        offset //= self.quant_config.pack_factor
 
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py
index df67758f71108..345f6494bf836 100644
--- a/vllm/model_executor/quantization_utils/__init__.py
+++ b/vllm/model_executor/quantization_utils/__init__.py
@@ -2,9 +2,11 @@
 
 from vllm.model_executor.quantization_utils.awq import AWQConfig
 from vllm.model_executor.quantization_utils.base import QuantizationConfig
+from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
 
 _QUANTIZATION_REGISTRY = {
     "awq": AWQConfig,
+    "squeezellm": SqueezeLLMConfig,
 }
diff --git a/vllm/model_executor/quantization_utils/awq.py b/vllm/model_executor/quantization_utils/awq.py
index 9c6160fbe7c4b..ebc89560a4477 100644
--- a/vllm/model_executor/quantization_utils/awq.py
+++ b/vllm/model_executor/quantization_utils/awq.py
@@ -60,13 +60,17 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
         return cls(weight_bits, group_size, zero_point)
 
     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros"]
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        return {"qweight": 1, "qzeros": 1}
 
     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
 
     @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
+    def get_col_parallel_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
+
+    @classmethod
+    def get_row_parallel_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py
index 3c91a1ac58176..a70a7a8631e41 100644
--- a/vllm/model_executor/quantization_utils/base.py
+++ b/vllm/model_executor/quantization_utils/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 
@@ -45,19 +45,25 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
                          "quantization config.")
 
     @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
+    def get_packed_tensors(cls) -> Dict[str, int]:
+        """Returns a dictionary of packed tensor names and their pack dims."""
         raise NotImplementedError
 
     @classmethod
-    def is_packed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is packed.
"""Returns True if a tensor is packed. + def get_packed_dim(cls, tensor_name: str) -> Optional[int]: + """Returns the pack dim of a tensor if it is packed. A tensor is considered packed if each element in the tensor is a packed representation of multiple elements in the original tensor. For example, an INT32 element in the tensor may represent 8 INT4 elements in the original tensor. + If the tensor is not packed, returns None. """ - return any(tag in tensor_name for tag in cls.get_packed_tensor_names()) + packed_tensors = cls.get_packed_tensors() + for packed_tensor_name, pack_dim in packed_tensors.items(): + if packed_tensor_name in tensor_name: + return pack_dim + return None @classmethod def get_transposed_tensor_names(cls) -> List[str]: @@ -71,5 +77,9 @@ def is_transposed(cls, tensor_name: str) -> bool: for tag in cls.get_transposed_tensor_names()) @classmethod - def get_tp_tensor_names(cls) -> List[str]: + def get_col_parallel_tensor_names(cls) -> List[str]: + raise NotImplementedError + + @classmethod + def get_row_parallel_tensor_names(cls) -> List[str]: raise NotImplementedError diff --git a/vllm/model_executor/quantization_utils/squeezellm.py b/vllm/model_executor/quantization_utils/squeezellm.py new file mode 100644 index 0000000000000..8a1db3e233217 --- /dev/null +++ b/vllm/model_executor/quantization_utils/squeezellm.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class SqueezeLLMConfig(QuantizationConfig): + """Config class for SqueezeLLM. + + Reference: https://arxiv.org/pdf/2306.07629 + """ + + def __init__( + self, + weight_bits: int, + ) -> None: + self.weight_bits = weight_bits + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"SqueezeLLM, but got {self.weight_bits} bits.") + + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" + + @classmethod + def get_name(cls) -> str: + return "squeezellm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + return cls(weight_bits) + + @classmethod + def get_packed_tensors(cls) -> Dict[str, int]: + return {"qweight": 0} + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["qweight"] + + @classmethod + def get_col_parallel_tensor_names(cls) -> List[str]: + return ["qweight", "lookup_table"] + + @classmethod + def get_row_parallel_tensor_names(cls) -> List[str]: + return ["qweight"]