-
Notifications
You must be signed in to change notification settings - Fork 27.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HQQ quantization support (#29637)
* update HQQ transformers integration * push import_utils.py * add force_hooks check in modeling_utils.py * fix | with Optional * force bias as param * check bias is Tensor * force forward for multi-gpu * review fixes pass * remove torch grad() * if any key in linear_tags fix * add cpu/disk check * isinstance return * add multigpu test + refactor tests * clean hqq_utils imports in hqq.py * clean hqq_utils imports in quantizer_hqq.py * delete hqq_utils.py * Delete src/transformers/utils/hqq_utils.py * ruff init * remove torch.float16 from __init__ in test * refactor test * isinstance -> type in quantizer_hqq.py * cpu/disk device_map check in quantizer_hqq.py * remove type(module) nn.linear check in quantizer_hqq.py * add BaseQuantizeConfig import inside HqqConfig init * remove hqq import in hqq.py * remove accelerate import from test_hqq.py * quant config.py doc update * add hqqconfig to main_classes doc * make style * __init__ fix * ruff __init__ * skip_modules list * hqqconfig format fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * test_hqq.py remove mistral comment * remove self.using_multi_gpu is False * torch_dtype default val set and logger.info * hqq.py isinstance fix * remove torch=None * torch_device test_hqq * rename test_hqq * MODEL_ID in test_hqq * quantizer_hqq setattr fix * quantizer_hqq typo fix * imports quantizer_hqq.py * isinstance quantizer_hqq * hqq_layer.bias reformat quantizer_hqq * Step 2 as comment in quantizer_hqq * prepare_for_hqq_linear() comment * keep_in_fp32_modules fix * HqqHfQuantizer reformat * quantization.md hqqconfig * quantization.md model example reformat * quantization.md # space * quantization.md space }) * quantization.md space }) * quantization_config fix doc Co-authored-by: amyeroberts <[email protected]> * axis value check in quantization_config * format * dynamic config explanation * quant config method in quantization.md * remove shard-level progress * .cuda fix modeling_utils * test_hqq fixes * make fix-copies --------- Co-authored-by: amyeroberts <[email protected]>
- Loading branch information
1 parent
4c94093
commit 5995299
Showing
16 changed files
with
681 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"HQQ (Half-Quadratic Quantization) integration file" | ||
|
||
from ..utils import is_hqq_available, is_torch_available, logging | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
# Name all modules inside the model | ||
def autoname_modules(model): | ||
for name, module in model.named_modules(): | ||
module.name = name | ||
|
||
|
||
# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj | ||
def name_to_linear_tag(name): | ||
return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))]) | ||
|
||
|
||
# Get all linear tags available | ||
def get_linear_tags(model): | ||
if is_hqq_available(): | ||
from hqq.core.quantize import HQQLinear | ||
|
||
linear_tags = set() | ||
for name, module in model.named_modules(): | ||
if isinstance(module, (torch.nn.Linear, HQQLinear)): | ||
linear_tags.add(name_to_linear_tag(name)) | ||
return list(linear_tags) | ||
|
||
|
||
def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None): | ||
for name, module in model.named_children(): | ||
if current_key_name is None: | ||
current_key_name = [] | ||
current_key_name.append(name) | ||
|
||
if isinstance(module, torch.nn.Linear): | ||
# Get linear tag | ||
linear_tag = name_to_linear_tag(module.name) | ||
|
||
# We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param() | ||
if linear_tag in patch_params: | ||
if patch_params[linear_tag] is not None: | ||
model._modules[name].quant_config = patch_params[linear_tag] | ||
# Store the module class in case we need to transpose the weight later | ||
model._modules[name].source_cls = type(module) | ||
# Force requires grad to False to avoid unexpected errors | ||
model._modules[name].requires_grad_(False) | ||
|
||
has_been_replaced = True | ||
|
||
if len(list(module.children())) > 0: | ||
_, has_been_replaced = _prepare_for_hqq_linear( | ||
module, | ||
patch_params=patch_params, | ||
has_been_replaced=has_been_replaced, | ||
) | ||
# Remove the last key for recursion | ||
current_key_name.pop(-1) | ||
|
||
return model, has_been_replaced | ||
|
||
|
||
def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False): | ||
""" | ||
Prepares nn.Linear layers for HQQ quantization. | ||
Since each layer type can have separate quantization parameters, we need to do the following: | ||
1- tag each module with its neme via autoname_modules() | ||
2- Extract linear_tags (e.g. ['self_attn.q_proj', ...]) | ||
3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear exepects it, this is referred to as patch_params | ||
""" | ||
|
||
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert | ||
|
||
# Add name to module | ||
autoname_modules(model) | ||
|
||
# Get linear tags. This allows us to use different quant params to different layer types | ||
linear_tags = get_linear_tags(model) | ||
|
||
# Convert quantization_config to layer-wise config | ||
skip_modules = quantization_config.skip_modules | ||
quant_config = quantization_config.to_dict() | ||
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert)) | ||
|
||
if any(key in linear_tags for key in quant_config.keys()): | ||
# If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None) | ||
patch_params = {key: None for key in linear_tags} | ||
patch_params.update(quant_config) | ||
else: | ||
# Same quant_config for all layers | ||
patch_params = {k: quant_config for k in linear_tags} | ||
|
||
model, has_been_replaced = _prepare_for_hqq_linear( | ||
model, patch_params=patch_params, has_been_replaced=has_been_replaced | ||
) | ||
|
||
# We store quantization config as linear_tag -> hqq quant config | ||
model.config.quantization_config = patch_params | ||
|
||
if not has_been_replaced: | ||
logger.warning("No linear modules were found in your model for quantization.") | ||
|
||
return model |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.