FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized models #30806

Merged: 11 commits, May 15, 2024
src/transformers/integrations/bitsandbytes.py (23 changes: 21 additions & 2 deletions)
@@ -377,7 +377,7 @@ def _create_accelerate_new_hook(old_hook):
     return new_hook
 
 
-def dequantize_and_replace(
+def _dequantize_and_replace(
     model,
     modules_to_not_convert=None,
     current_key_name=None,
@@ -434,7 +434,7 @@ def dequantize_and_replace(
             new_module.to(device)
             model._modules[name] = new_module
         if len(list(module.children())) > 0:
-            _, has_been_replaced = dequantize_and_replace(
+            _, has_been_replaced = _dequantize_and_replace(
                 module,
                 modules_to_not_convert,
                 current_key_name,
@@ -444,3 +444,22 @@ def dequantize_and_replace(
         # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, has_been_replaced
amyeroberts (Collaborator) commented on May 15, 2024:

One general comment: if instead you had a private method _dequantize_and_replace that handles the recursion, you wouldn't need to return has_been_replaced here. When someone calls dequantize_and_replace, I don't think has_been_replaced is ever used, and it could be confusing, e.g.:

# This is just dequantize_and_replace from before
def _dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    ...
    return model, has_been_replaced

def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    model, has_been_replaced = _dequantize_and_replace(...)
    return model

The PR author (Contributor) replied:

Makes sense! Will do!

The PR author (Contributor) replied:

Done in 8b904f7!
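
As a self-contained illustration of the pattern suggested above (toy names, not the actual transformers or bitsandbytes code): the private function threads the bookkeeping flag through the recursion, and the public wrapper hides it from callers.

import torch.nn as nn

def _replace_linears(model, has_been_replaced=False):
    # Private helper: the recursion threads the flag through every level.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            model._modules[name] = nn.Identity()
            has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_linears(module, has_been_replaced)
    return model, has_been_replaced

def replace_linears(model):
    # Public wrapper: callers get the model back and never see the flag.
    model, has_been_replaced = _replace_linears(model)
    if not has_been_replaced:
        print("No nn.Linear modules were found.")
    return model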



+
+
+def dequantize_and_replace(
+    model,
+    modules_to_not_convert=None,
+    quantization_config=None,
+):
+    model, has_been_replaced = _dequantize_and_replace(
+        model,
+        modules_to_not_convert=modules_to_not_convert,
+        quantization_config=quantization_config,
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+        )
+
+    return model

amyeroberts (Collaborator) commented on the logger.warning call, May 15, 2024:

Nice :)
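
For context, the user-facing entry point this PR adds is a dequantize() method on the model, which delegates to the quantizer changes shown below. Roughly, and only as a paraphrased sketch (the merged implementation may differ in names and error messages):

# Paraphrased sketch of the model-level entry point; not the verbatim source.
def dequantize(self):
    hf_quantizer = getattr(self, "hf_quantizer", None)
    if hf_quantizer is None:
        raise ValueError("The model is not quantized, so there is nothing to dequantize.")
    return hf_quantizer.dequantize(self)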
src/transformers/quantizers/quantizer_bnb_4bit.py (2 changes: 1 addition & 1 deletion)

@@ -316,7 +316,7 @@ def is_trainable(self) -> bool:
     def _dequantize(self, model):
         from ..integrations import dequantize_and_replace
 
-        model, has_been_replaced = dequantize_and_replace(
+        model = dequantize_and_replace(
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         return model
src/transformers/quantizers/quantizer_bnb_8bit.py (2 changes: 1 addition & 1 deletion)

@@ -285,7 +285,7 @@ def is_trainable(self) -> bool:
     def _dequantize(self, model):
         from ..integrations import dequantize_and_replace
 
-        model, has_been_replaced = dequantize_and_replace(
+        model = dequantize_and_replace(
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         return model
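
Taken together, the intended end-to-end usage looks something like the sketch below (the checkpoint id is a placeholder; 4-bit loading with bitsandbytes needs a CUDA GPU):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load a checkpoint quantized on the fly to 4-bit with bitsandbytes.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# The new API from this PR: turn the bnb modules back into plain nn.Linear.
model = model.dequantize()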