Skip to content

Commit

Permalink
[Model] Add BNB quantization support for Mllama (#9720)
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored Oct 29, 2024
1 parent ef7865b commit 09500f7
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 12 deletions.
35 changes: 31 additions & 4 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
Expand All @@ -23,7 +24,7 @@ def __init__(
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[Any] = None,
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 0.0,
) -> None:

Expand All @@ -34,11 +35,15 @@ def __init__(
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold

def __repr__(self) -> str:
return "BitsAndBytesConfig"
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

@classmethod
def get_name(self) -> str:
Expand Down Expand Up @@ -102,15 +107,21 @@ def get_safe_value(config, keys, default_value=None):
llm_int8_threshold=llm_int8_threshold)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
return any(module_name in prefix for module_name in llm_int8_skip_modules)


class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
Expand Down Expand Up @@ -211,6 +222,11 @@ def _apply_8bit_weight(
from bitsandbytes import MatmulLtState, matmul

original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)

qweight = layer.qweight
Expand Down Expand Up @@ -265,6 +281,9 @@ def _apply_8bit_weight(

out = out.to(original_type)

if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))

if bias is not None:
out += bias

Expand All @@ -282,6 +301,11 @@ def _apply_4bit_weight(
from bitsandbytes import matmul_4bit

original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)

qweight = layer.qweight
Expand Down Expand Up @@ -310,6 +334,9 @@ def _apply_4bit_weight(

out = out.to(original_type)

if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))

if bias is not None:
out += bias

Expand Down
19 changes: 16 additions & 3 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,19 @@ def _get_quantized_weights_iterator(
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict

def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix)
for suffix in quantized_suffix)

def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax", "quant_map", "nested_absmax", "nested_quant_map",
"bitsandbytes"
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)

def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
Expand All @@ -912,7 +925,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):

if not weight_name.endswith((".weight", ".bias")):
if self._is_8bit_weight_name(weight_name):
continue

qweight_name = weight_name.replace(".weight", ".qweight")
Expand All @@ -932,7 +945,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith((".weight", ".bias")):
if not self._is_4bit_weight_name(weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
Expand All @@ -956,7 +969,7 @@ def _parse_quant_state(param_name: str,
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):

if not weight_name.endswith((".weight", ".bias")):
if self._is_4bit_weight_name(weight_name):
continue

if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
Expand Down
42 changes: 37 additions & 5 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,10 @@ def forward(self, hidden_state: torch.Tensor,
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):

def __init__(self, config: config_mllama.MllamaVisionConfig):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()

model_parallel_size = get_tensor_model_parallel_world_size()
Expand All @@ -341,12 +344,16 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
self.head_dim,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

def forward(
Expand Down Expand Up @@ -393,7 +400,8 @@ def __init__(
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size

self.self_attn = MllamaVisionSdpaAttention(config)
self.self_attn = MllamaVisionSdpaAttention(
config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
Expand Down Expand Up @@ -1002,6 +1010,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)

def forward(
Expand Down Expand Up @@ -1037,6 +1046,26 @@ def forward(
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(self,
config: config_mllama.MllamaConfig,
Expand All @@ -1061,10 +1090,13 @@ def __init__(self,
quant_config=quant_config,
prefix="language_model",
)
self.multi_modal_projector = nn.Linear(
self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
quant_config=quant_config,
gather_output=True,
prefix="multi_modal_projector",
)
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)
Expand Down Expand Up @@ -1128,7 +1160,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
raise ValueError("No images provided.")
max_num_tiles = max(
max([len(x) for x in y[0]]) for y in pixel_values)
device = self.multi_modal_projector.weight.device
device = next(self.multi_modal_projector.parameters()).device
bsz = len(pixel_values)
out_num_tiles = []
out_images = torch.zeros(
Expand Down Expand Up @@ -1204,7 +1236,7 @@ def get_cross_attention_states(
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
cross_attention_states = self.multi_modal_projector(
cross_attention_states, _ = self.multi_modal_projector(
cross_attention_states)

bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
Expand Down

0 comments on commit 09500f7

Please sign in to comment.