support bitsandbytes quantization with more models (vllm-project#9148)
chenqianfzh authored Oct 9, 2024
1 parent 9ba0bd6 commit 2f4117c
Showing 10 changed files with 164 additions and 27 deletions.
13 changes: 7 additions & 6 deletions tests/quantization/test_bitsandbytes.py
@@ -9,22 +9,22 @@
import torch

from tests.quantization.utils import is_quant_method_supported

from ..utils import fork_new_process_for_each_test
from tests.utils import fork_new_process_for_each_test

models_4bit_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
("facebook/opt-125m", "quantize opt model inflight"),
]

models_pre_qaunt_4bit_to_test = [
('lllyasviel/omost-llama-3-8b-4bits',
'read pre-quantized 4-bit NF4 model'),
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
]

models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
('meta-llama/Llama-Guard-3-8B-INT8',
'read pre-quantized llama 8-bit model'),
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]


@@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,
hf_str = hf_log["generated_text"]
vllm_str = vllm_log["generated_text"]
prompt = hf_log["prompt"]

assert hf_str == vllm_str, (f"Model: {model_name}"
f"Mismatch between HF and vLLM outputs:\n"
f"Prompt: {prompt}\n"
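For reference, here is a minimal end-to-end sketch of exercising one of these checkpoints through vLLM's offline API. This is not part of the commit's test suite; it assumes the `quantization="bitsandbytes"` and `load_format="bitsandbytes"` options this series builds on, and argument names may differ between vLLM versions.

```python
# Illustrative usage sketch -- not part of this commit's test suite.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    quantization="bitsandbytes",  # route linear layers through the BitsAndBytes method
    load_format="bitsandbytes",   # load weights via BitsAndBytesModelLoader
    dtype="float16",
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
```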
26 changes: 25 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -336,8 +336,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

param_data = param.data
if output_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -821,6 +825,9 @@ def weight_loader(self,
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)

packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
@@ -834,6 +841,23 @@
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size,
self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}

shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)

loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
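To make the intent of the linear-layer change above concrete, here is a simplified, stand-alone sketch (not the vLLM implementation) of why tensor-parallel narrowing is skipped when `use_bitsandbytes_4bit` is set: the bitsandbytes loader already yields only the portion of the weight owned by the current rank, so slicing it again would drop data.

```python
from typing import Optional

import torch


def load_sharded_weight(param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor,
                        output_dim: Optional[int],
                        tp_rank: int) -> None:
    """Simplified sketch of the TP weight-loading rule changed above."""
    use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
    param_data = param.data
    if output_dim is not None and not use_bnb_4bit:
        # Regular path: the checkpoint holds the full tensor, so each rank
        # narrows out its own slice along the output dimension.
        shard_size = param_data.shape[output_dim]
        loaded_weight = loaded_weight.narrow(output_dim,
                                             tp_rank * shard_size, shard_size)
    # bitsandbytes 4-bit path: the loader already produced only this rank's
    # portion of the weight, so it is copied as-is.
    param_data.copy_(loaded_weight)
```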
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -108,7 +108,7 @@ def get_quant_method(self, layer: torch.nn.Module,
return None

def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
return []


class BitsAndBytesLinearMethod(LinearMethodBase):
@@ -236,7 +236,7 @@ def _apply_8bit_weight(
if generation == 0 or generation == 1:
matmul_states[i] = MatmulLtState()
matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
matmul_states[i].SCB = quant_states[i]
matmul_states[i].SCB = quant_states[i].to(x.device)
matmul_states[i].threshold = (
self.quant_config.llm_int8_threshold)
matmul_states[i].has_fp16_weights = (
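The one-line `SCB` change above moves the per-column scale statistics onto the activation tensor's device before the int8 matmul runs. A hedged sketch of the pattern, using the `MatmulLtState` fields that already appear in this diff (the import path is an assumption and may differ across bitsandbytes versions):

```python
import torch
from bitsandbytes import MatmulLtState  # assumed import path


def build_matmul_state(qweight: torch.Tensor, scb: torch.Tensor,
                       x: torch.Tensor,
                       llm_int8_threshold: float = 6.0) -> MatmulLtState:
    """Sketch: prepare an LLM.int8() matmul state for one weight shard."""
    state = MatmulLtState()
    state.CB = qweight            # int8-quantized weight block
    state.SCB = scb.to(x.device)  # scales must live on the activations' device
    state.threshold = llm_int8_threshold
    state.has_fp16_weights = False
    return state
```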
62 changes: 44 additions & 18 deletions vllm/model_executor/model_loader/loader.py
@@ -736,15 +736,26 @@ def save_model(
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""

# TODO: these module names are for Llama only,
# change so that it works with other models as well
possible_config_file_names = ["adapter_config.json"]

default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
'.fc1.',
'.fc2.',
'.dense.',
'.query_key_value.',
'.qkv_proj.',
'.dense_h_to_4h.',
'.dense_4h_to_h.',
'.out_proj.',
]

possible_config_file_names = ["adapter_config.json"]

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)

@@ -754,7 +765,7 @@ def __init__(self, load_config: LoadConfig):
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
self.target_modules = self.default_target_modules
self.target_modules = []
return

qlora_adapter = load_config.model_loader_extra_config[
@@ -901,10 +912,11 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):

if not weight_name.endswith(".weight"):
if not weight_name.endswith((".weight", ".bias")):
continue

qweight_name = weight_name.replace(".weight", ".qweight")

if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor
@@ -920,7 +932,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
if weight_name.endswith((".weight", ".bias")):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
@@ -943,9 +955,10 @@ def _parse_quant_state(param_name: str,
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):

if not weight_name.endswith((".weight", ".bias")):
continue

if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
in temp_state_dict) or \
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
@@ -965,15 +978,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):

if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight")

# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
if any(module in weight_name
for module in self.column_parallel_weights_modules):

total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
@@ -1022,6 +1034,20 @@ def _load_weights(self, model_config: ModelConfig,
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")

if len(self.target_modules) == 0:
if hasattr(model, 'default_bitsandbytes_target_modules'):
self.target_modules = model.default_bitsandbytes_target_modules
else:
self.target_modules = self.default_target_modules

if hasattr(model, 'column_parallel_weights_modules'):
self.column_parallel_weights_modules = \
model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []

self.model_type = type(model).__name__

logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")

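A condensed sketch of the two loader behaviours added above: falling back from adapter-provided target modules to per-model lists (and finally to the generic defaults), and slicing column-parallel weights along the last dimension per tensor-parallel rank. Names here are illustrative, not the exact vLLM helpers.

```python
from typing import Sequence

import torch


def resolve_target_modules(model,
                           user_targets: Sequence[str],
                           defaults: Sequence[str]) -> Sequence[str]:
    """Sketch: adapter-provided targets win, then per-model lists, then defaults."""
    if user_targets:
        return user_targets
    return getattr(model, "default_bitsandbytes_target_modules", defaults)


def shard_column_parallel(weight: torch.Tensor, tp_rank: int,
                          tp_size: int) -> torch.Tensor:
    """Sketch: column-parallel weights are sliced along the last dim per TP rank."""
    total_size = weight.size(-1)
    start_index = total_size // tp_size * tp_rank
    end_index = total_size // tp_size * (tp_rank + 1)
    return weight[..., start_index:end_index]
```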
11 changes: 11 additions & 0 deletions vllm/model_executor/models/falcon.py
@@ -391,6 +391,17 @@ def forward(

class FalconForCausalLM(nn.Module, SupportsPP):

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]

def __init__(
self,
config: FalconConfig,
22 changes: 22 additions & 0 deletions vllm/model_executor/models/gemma.py
@@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj",
"down_proj",
]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
13 changes: 13 additions & 0 deletions vllm/model_executor/models/gemma2.py
@@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
13 changes: 13 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]

# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
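The `bitsandbytes_stacked_params_mapping` attributes added to these model classes describe how per-shard checkpoint names (`q_proj`, `k_proj`, `v_proj`, ...) correspond to vLLM's fused parameters (`qkv_proj`, `gate_up_proj`). Below is a hypothetical standalone sketch of how such a mapping could be consumed; the real consumption logic lives in `BitsAndBytesModelLoader` and may differ in detail.

```python
from typing import Optional, Tuple

# Hypothetical standalone copy of a model's bitsandbytes_stacked_params_mapping.
STACKED_PARAMS_MAPPING = {
    # shard_name: (fused_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def map_checkpoint_name(weight_name: str) -> Tuple[str, Optional[int]]:
    """Map a per-shard HF weight name onto the fused vLLM parameter name."""
    for shard_name, (fused_name, index) in STACKED_PARAMS_MAPPING.items():
        if f".{shard_name}." in weight_name:
            return weight_name.replace(shard_name, fused_name), index
    return weight_name, None


# Example: "model.layers.0.self_attn.q_proj.weight"
#   -> ("model.layers.0.self_attn.qkv_proj.weight", 0)
```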
13 changes: 13 additions & 0 deletions vllm/model_executor/models/opt.py
@@ -315,6 +315,19 @@ def forward(

class OPTForCausalLM(nn.Module, SupportsPP):

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".out_proj.", ".fc2."]

def __init__(
self,
config: OPTConfig,
14 changes: 14 additions & 0 deletions vllm/model_executor/models/phi.py
@@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc1",
"fc2",
]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".fc2.", ".dense."]

embedding_modules = {}
embedding_padding_modules = []

