
FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models #30806

Merged
merged 11 commits on May 15, 2024
23 changes: 22 additions & 1 deletion docs/source/en/quantization.md
@@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
```

### Dequantizing `bitsandbytes` models

Once quantized, you can dequantize the model back to its original precision. Note that this might result in a small loss of model quality. Also make sure you have enough GPU RAM to fit the dequantized model.
Below is how to dequantize a 4-bit model with `bitsandbytes`.

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True))
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.dequantize()

text = tokenizer("Hello my name is", return_tensors="pt").to(0)

out = model.generate(**text)
print(tokenizer.decode(out[0]))
```
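
The same `dequantize()` call also works for 8-bit models, since both the 4-bit and 8-bit bitsandbytes quantizers implement it. A minimal sketch, reusing the same checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "facebook/opt-125m"

# Load the model in 8-bit instead of 4-bit.
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Convert the quantized linear layers back to regular torch.nn.Linear.
model.dequantize()

text = tokenizer("Hello my name is", return_tensors="pt").to(0)

out = model.generate(**text)
print(tokenizer.decode(out[0]))
```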

## EETQ
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUs. The high-performance GEMM and GEMV kernels come from FasterTransformer and TensorRT-LLM. It requires no calibration dataset, and the model does not need to be pre-quantized. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.

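As an illustration, here is a minimal sketch of quantizing a model on the fly through the `EetqConfig` integration; the `"int8"` argument follows the transformers EETQ integration, and the checkpoint name is only an example:

```python
from transformers import AutoModelForCausalLM, EetqConfig

# int8 weight-only quantization, applied while the checkpoint is loaded.
quantization_config = EetqConfig("int8")

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", device_map="auto", quantization_config=quantization_config
)
```
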
@@ -794,4 +815,4 @@ model = transformers.AutoModelForCausalLM.from_pretrained(
### Optimized Runtime
HQQ supports various backends, including pure PyTorch and custom dequantization CUDA kernels. These backends are suitable for older GPUs and PEFT/QLoRA training.
For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
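
For illustration, a minimal sketch of selecting a backend; `HQQLinear.set_backend` and the `HQQBackend` enum are taken from the hqq README and should be verified against the version you install:

```python
# Sketch only — assumes the hqq package exposes HQQLinear.set_backend / HQQBackend
# as documented in the repository linked above.
from hqq.core.quantize import HQQBackend, HQQLinear

# Pure-PyTorch backend: compatible with older GPUs and PEFT/QLoRA training.
HQQLinear.set_backend(HQQBackend.PYTORCH)
```
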
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -25,6 +25,7 @@
"replace_with_awq_linear",
],
"bitsandbytes": [
"dequantize_and_replace",
"get_keys_to_not_convert",
"replace_8bit_linear",
"replace_with_bnb_linear",
@@ -97,6 +98,7 @@
        replace_with_awq_linear,
    )
    from .bitsandbytes import (
        dequantize_and_replace,
        get_keys_to_not_convert,
        replace_8bit_linear,
        replace_with_bnb_linear,
122 changes: 122 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
@@ -1,4 +1,5 @@
import importlib.metadata
import inspect
import warnings
from copy import deepcopy
from inspect import signature
@@ -16,7 +17,9 @@
from ..pytorch_utils import Conv1D

if is_accelerate_available():
    import accelerate
    from accelerate import init_empty_weights
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
    from accelerate.utils import find_tied_parameters

logger = logging.get_logger(__name__)
@@ -322,3 +325,122 @@ def get_keys_to_not_convert(model):
        filtered_module_names.append(name)

    return filtered_module_names


# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """
    Helper function to dequantize 4bit or 8bit bnb weights.

    If the weight is not a bnb quantized weight, it will be returned as is.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    if cls_name == "Params4bit":
        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        logger.warning_once(
            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        )
        return output_tensor

    if state.SCB is None:
        state.SCB = weight.SCB

    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()


def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing !
    This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
    with some changes
    """
    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    old_hook_attr = old_hook.__dict__
    filtered_old_hook_attr = {}
    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
    for k in old_hook_attr.keys():
        if k in old_hook_init_signature.parameters:
            filtered_old_hook_attr[k] = old_hook_attr[k]
    new_hook = old_hook_cls(**filtered_old_hook_attr)
    return new_hook


def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
"""
Converts a quantized model into its unquantized original version. The newly converted model will have
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
some performance drop compared to the original model before quantization - use it only for specific usecases
such as QLoRA adapters merging.

Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
    quant_method = quantization_config.quantization_method()

    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, target_cls) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)

            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                bias = getattr(module, "bias", None)

                device = module.weight.device
                with init_empty_weights():
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

                if quant_method == "llm_int8":
                    state = module.state
                else:
                    state = None

                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))

                if bias is not None:
                    new_module.bias = bias

                # Create a new hook and attach it in case we use accelerate
                if hasattr(module, "_hf_hook"):
                    old_hook = module._hf_hook
                    new_hook = _create_accelerate_new_hook(old_hook)

                    remove_hook_from_module(module)
                    add_hook_to_module(new_module, new_hook)

                new_module.to(device)
                model._modules[name] = new_module
        if len(list(module.children())) > 0:
            _, has_been_replaced = dequantize_and_replace(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced

@amyeroberts (Collaborator) commented on May 15, 2024:

One general comment: if instead you had a private method `_dequantize_and_replace` which handles the recursion, you wouldn't need to return `has_been_replaced` here. When someone calls `dequantize_and_replace`, I don't think `has_been_replaced` is ever used and it could be confusing, e.g.:

# This is just dequantize_and_replace from before
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    ...
    return model, has_been_replaced

def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    model, has_been_replaced = _dequantize_and_replace(...)
    return model 

@younesbelkada (Contributor, Author) replied:
makes sense ! Will do !

@younesbelkada (Contributor, Author) replied:
Done in 8b904f7 !

12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1327,6 +1327,18 @@ def post_init(self):
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()

    def dequantize(self):
        """
        Potentially dequantize the model in case it has been quantized by a quantization method that supports
        dequantization.
        """
        hf_quantizer = getattr(self, "hf_quantizer", None)

        if hf_quantizer is None:
            raise ValueError("You need to first quantize your model in order to dequantize it")

        return hf_quantizer.dequantize(self)

    def _backward_compatibility_gradient_checkpointing(self):
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()
17 changes: 17 additions & 0 deletions src/transformers/quantizers/base.py
@@ -194,6 +194,23 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
return self._process_model_after_weight_loading(model, **kwargs)

    def dequantize(self, model):
        """
        Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
        Note that not all quantization schemes support this.
        """
        model = self._dequantize(model)

        # Delete quantizer and quantization config
        del model.hf_quantizer

        return model

    def _dequantize(self, model):
        raise NotImplementedError(
            f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
        )

    @abstractmethod
    def _process_model_before_weight_loading(self, model, **kwargs):
        ...
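
For context, the split between the public `dequantize` and the private `_dequantize` hook means a future quantizer only has to override the private method; the public one handles deleting the `hf_quantizer` reference. A hypothetical sketch (the class and its replacement logic are illustrative, not part of this PR):

```python
from transformers.quantizers.base import HfQuantizer


class MyBackendQuantizer(HfQuantizer):
    # The other abstract methods (_process_model_before_weight_loading, etc.)
    # would still need to be implemented; they are omitted in this sketch.

    def _dequantize(self, model):
        # Hypothetical: swap each backend-specific quantized layer back to a
        # plain torch.nn.Linear, then return the model. The base class
        # `dequantize()` removes `model.hf_quantizer` afterwards.
        return model
```
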
8 changes: 8 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -312,3 +312,11 @@ def is_serializable(self):
    @property
    def is_trainable(self) -> bool:
        return True

    def _dequantize(self, model):
        from ..integrations import dequantize_and_replace

        model, has_been_replaced = dequantize_and_replace(
            model, self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        return model
8 changes: 8 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -281,3 +281,11 @@ def is_serializable(self):
    @property
    def is_trainable(self) -> bool:
        return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")

    def _dequantize(self, model):
        from ..integrations import dequantize_and_replace

        model, has_been_replaced = dequantize_and_replace(
            model, self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        return model
17 changes: 17 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -239,6 +239,23 @@ def test_generate_quality_config(self):

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_dequantize(self):
        r"""
        Test that loading the model and dequantizing it produces correct results
        """
        bnb_config = BitsAndBytesConfig(load_in_4bit=True)

        model_4bit = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        model_4bit.dequantize()

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
17 changes: 17 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -285,6 +285,23 @@ def test_generate_quality_config(self):

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_dequantize(self):
        r"""
        Test that loading the model and dequantizing it produces correct results
        """
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

        model_8bit = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        model_8bit.dequantize()

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_raise_if_config_and_load_in_8bit(self):
        r"""
        Test that loading the model with the config and `load_in_8bit` raises an error