From 2f4117c38e101ee63b65521c93b22efe3526f77e Mon Sep 17 00:00:00 2001 From: chenqianfzh <51831990+chenqianfzh@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:52:19 -0700 Subject: [PATCH] support bitsandbytes quantization with more models (#9148) --- tests/quantization/test_bitsandbytes.py | 13 ++-- vllm/model_executor/layers/linear.py | 26 +++++++- .../layers/quantization/bitsandbytes.py | 4 +- vllm/model_executor/model_loader/loader.py | 62 +++++++++++++------ vllm/model_executor/models/falcon.py | 11 ++++ vllm/model_executor/models/gemma.py | 22 +++++++ vllm/model_executor/models/gemma2.py | 13 ++++ vllm/model_executor/models/llama.py | 13 ++++ vllm/model_executor/models/opt.py | 13 ++++ vllm/model_executor/models/phi.py | 14 +++++ 10 files changed, 164 insertions(+), 27 deletions(-) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index ac2ebc622ba6f..f2acf0d70afef 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -9,22 +9,22 @@ import torch from tests.quantization.utils import is_quant_method_supported - -from ..utils import fork_new_process_for_each_test +from tests.utils import fork_new_process_for_each_test models_4bit_to_test = [ - ('huggyllama/llama-7b', 'quantize model inflight'), + ("facebook/opt-125m", "quantize opt model inflight"), ] models_pre_qaunt_4bit_to_test = [ - ('lllyasviel/omost-llama-3-8b-4bits', - 'read pre-quantized 4-bit NF4 model'), ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), + ('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'), ] models_pre_quant_8bit_to_test = [ - ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'), + ('meta-llama/Llama-Guard-3-8B-INT8', + 'read pre-quantized llama 8-bit model'), + ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"), ] @@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner, hf_str = hf_log["generated_text"] vllm_str = vllm_log["generated_text"] prompt = hf_log["prompt"] + assert hf_str == vllm_str, (f"Model: {model_name}" f"Mismatch between HF and vLLM outputs:\n" f"Prompt: {prompt}\n" diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c162ab81c5530..a3d1dc2c76d21 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -336,8 +336,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + param_data = param.data - if output_dim is not None: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, @@ -821,6 +825,9 @@ def weight_loader(self, ("v", (self.total_num_heads + self.total_num_kv_heads) * self.head_size, self.total_num_kv_heads * self.head_size), ] + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: # Special case for Quantized Weights. 
@@ -834,6 +841,23 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.total_num_heads * self.head_size), + "k": (self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + "v": + ((self.total_num_heads + self.total_num_kv_heads) * + self.head_size, + self.total_num_kv_heads * self.head_size), + "total": + ((self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_size, 0) + } + + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, shard_id) + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 38495d5a5a863..faa8d92e83de3 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -108,7 +108,7 @@ def get_quant_method(self, layer: torch.nn.Module, return None def get_scaled_act_names(self) -> List[str]: - return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + return [] class BitsAndBytesLinearMethod(LinearMethodBase): @@ -236,7 +236,7 @@ def _apply_8bit_weight( if generation == 0 or generation == 1: matmul_states[i] = MatmulLtState() matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] - matmul_states[i].SCB = quant_states[i] + matmul_states[i].SCB = quant_states[i].to(x.device) matmul_states[i].threshold = ( self.quant_config.llm_int8_threshold) matmul_states[i].has_fp16_weights = ( diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8d4163ec88490..813f58339da37 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -736,15 +736,26 @@ def save_model( class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" - # TODO: these module names are for Llama only, - # change so that it works with other models as well + possible_config_file_names = ["adapter_config.json"] + default_target_modules = [ - "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", - "o_proj" + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + '.fc1.', + '.fc2.', + '.dense.', + '.query_key_value.', + '.qkv_proj.', + '.dense_h_to_4h.', + '.dense_4h_to_h.', + '.out_proj.', ] - possible_config_file_names = ["adapter_config.json"] - def __init__(self, load_config: LoadConfig): super().__init__(load_config) @@ -754,7 +765,7 @@ def __init__(self, load_config: LoadConfig): if (not load_config.model_loader_extra_config or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config): - self.target_modules = self.default_target_modules + self.target_modules = [] return qlora_adapter = load_config.model_loader_extra_config[ @@ -901,10 +912,11 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - if not weight_name.endswith(".weight"): + if not weight_name.endswith((".weight", ".bias")): continue qweight_name = weight_name.replace(".weight", ".qweight") + if qweight_name in quant_state_dict: set_weight_attrs(weight_tensor, {"load_in_8bit": True}) yield qweight_name, weight_tensor @@ 
-920,7 +932,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, use_safetensors) temp_state_dict = {} for weight_name, weight_tensor in weight_iterator: - if weight_name.endswith(".weight"): + if weight_name.endswith((".weight", ".bias")): continue # bitsandbytes library requires # weight.quant_state.bitsandbytes__* in CPU @@ -943,9 +955,10 @@ def _parse_quant_state(param_name: str, # pre quantized weights would have a quant_state for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - # Filter out all weights whose suffix is not ".weight" - if not weight_name.endswith(".weight"): + + if not weight_name.endswith((".weight", ".bias")): continue + if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ in temp_state_dict) or \ (f"{weight_name}.quant_state.bitsandbytes__fp4" \ @@ -965,15 +978,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - if any(target_module in weight_name - for target_module in self.target_modules): + + if any(target_module in weight_name for target_module in + self.target_modules) and weight_name.endswith(".weight"): weight_name = weight_name.replace(".weight", ".qweight") - # weight partitions of different modules occur at - # different dimensions - # TODO: these module names are for Llama only, - # change so that it works with other models as well - if 'down_proj' in weight_name or 'o_proj' in weight_name: + if any(module in weight_name + for module in self.column_parallel_weights_modules): + total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) @@ -1022,6 +1034,20 @@ def _load_weights(self, model_config: ModelConfig, f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet.") + if len(self.target_modules) == 0: + if hasattr(model, 'default_bitsandbytes_target_modules'): + self.target_modules = model.default_bitsandbytes_target_modules + else: + self.target_modules = self.default_target_modules + + if hasattr(model, 'column_parallel_weights_modules'): + self.column_parallel_weights_modules = \ + model.column_parallel_weights_modules + else: + self.column_parallel_weights_modules = [] + + self.model_type = type(model).__name__ + logger.info("Loading weights with BitsAndBytes quantization. 
" " May take a while ...") diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index a20dd93cee18c..467a33505ee12 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -391,6 +391,17 @@ def forward( class FalconForCausalLM(nn.Module, SupportsPP): + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = {} + default_bitsandbytes_target_modules = [ + ".query_key_value.", + ".dense.", + ".dense_h_to_4h.", + ".dense_4h_to_h.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."] + def __init__( self, config: FalconConfig, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index ca419891f69db..91e556db70a0b 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "gate_up_proj", "down_proj", ] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + # Gemma does not apply LoRA to the embedding layer. embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index bd3c1114c929f..f1899d92b02b6 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # Gemma does not apply LoRA to the embedding layer. 
embedding_modules = {} embedding_padding_modules = [] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8eacf73dd6322..4b4e024578789 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "lm_head": "output_embeddings" } embedding_padding_modules = ["lm_head"] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 727dd65acc749..3bcdb0d87fd52 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -315,6 +315,19 @@ def forward( class OPTForCausalLM(nn.Module, SupportsPP): + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + } + default_bitsandbytes_target_modules = [ + ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".out_proj.", ".fc2."] + def __init__( self, config: OPTConfig, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index c90fe2e0ab9ea..0918f21a40e27 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "fc1", "fc2", ] + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + } + default_bitsandbytes_target_modules = [ + ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense." + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".fc2.", ".dense."] + embedding_modules = {} embedding_padding_modules = []
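
For reference, a minimal sketch of exercising the expanded support from Python, mirroring the in-flight quantization path covered by the updated tests/quantization/test_bitsandbytes.py. It assumes the LLM entry point accepts the same quantization="bitsandbytes" and load_format="bitsandbytes" arguments the test passes to vllm_runner; the prompt and sampling settings are illustrative, not taken from the patch.

from vllm import LLM, SamplingParams

# Sketch only: in-flight bitsandbytes quantization of one of the newly
# supported architectures (OPT). Pre-quantized 4-bit/8-bit checkpoints go
# through the same BitsAndBytesModelLoader path.
llm = LLM(model="facebook/opt-125m",
          quantization="bitsandbytes",
          load_format="bitsandbytes")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)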