From 59952994c422ae9fc17f8a3134c00396c948c42e Mon Sep 17 00:00:00 2001 From: mobicham <37179323+mobicham@users.noreply.github.com> Date: Thu, 2 May 2024 18:51:49 +0200 Subject: [PATCH] Add HQQ quantization support (#29637) * update HQQ transformers integration * push import_utils.py * add force_hooks check in modeling_utils.py * fix | with Optional * force bias as param * check bias is Tensor * force forward for multi-gpu * review fixes pass * remove torch grad() * if any key in linear_tags fix * add cpu/disk check * isinstance return * add multigpu test + refactor tests * clean hqq_utils imports in hqq.py * clean hqq_utils imports in quantizer_hqq.py * delete hqq_utils.py * Delete src/transformers/utils/hqq_utils.py * ruff init * remove torch.float16 from __init__ in test * refactor test * isinstance -> type in quantizer_hqq.py * cpu/disk device_map check in quantizer_hqq.py * remove type(module) nn.linear check in quantizer_hqq.py * add BaseQuantizeConfig import inside HqqConfig init * remove hqq import in hqq.py * remove accelerate import from test_hqq.py * quant config.py doc update * add hqqconfig to main_classes doc * make style * __init__ fix * ruff __init__ * skip_modules list * hqqconfig format fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * hqqconfig doc fix * test_hqq.py remove mistral comment * remove self.using_multi_gpu is False * torch_dtype default val set and logger.info * hqq.py isinstance fix * remove torch=None * torch_device test_hqq * rename test_hqq * MODEL_ID in test_hqq * quantizer_hqq setattr fix * quantizer_hqq typo fix * imports quantizer_hqq.py * isinstance quantizer_hqq * hqq_layer.bias reformat quantizer_hqq * Step 2 as comment in quantizer_hqq * prepare_for_hqq_linear() comment * keep_in_fp32_modules fix * HqqHfQuantizer reformat * quantization.md hqqconfig * quantization.md model example reformat * quantization.md # space * quantization.md space }) * quantization.md space }) * quantization_config fix doc Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * axis value check in quantization_config * format * dynamic config explanation * quant config method in quantization.md * remove shard-level progress * .cuda fix modeling_utils * test_hqq fixes * make fix-copies --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../Dockerfile | 3 + docs/source/en/main_classes/quantization.md | 4 + docs/source/en/quantization.md | 50 +++++ docs/source/en/quicktour.md | 0 src/transformers/__init__.py | 2 + src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/hqq.py | 121 +++++++++++ .../integrations/integration_utils.py | 0 src/transformers/modeling_utils.py | 11 + src/transformers/quantizers/__init__.py | 0 src/transformers/quantizers/auto.py | 4 + src/transformers/quantizers/quantizer_hqq.py | 200 ++++++++++++++++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + src/transformers/utils/quantization_config.py | 112 +++++++++- tests/quantization/hqq/test_hqq.py | 167 +++++++++++++++ 16 files changed, 681 insertions(+), 1 deletion(-) mode change 100644 => 100755 docker/transformers-quantization-latest-gpu/Dockerfile mode change 100644 => 100755 docs/source/en/main_classes/quantization.md mode change 100644 => 100755 docs/source/en/quantization.md mode change 100644 => 100755 docs/source/en/quicktour.md mode change 100644 => 100755 src/transformers/__init__.py mode change 100644 => 100755 src/transformers/integrations/__init__.py create mode 100755 src/transformers/integrations/hqq.py mode change 100644 => 100755 src/transformers/integrations/integration_utils.py mode change 100644 => 100755 src/transformers/modeling_utils.py mode change 100644 => 100755 src/transformers/quantizers/__init__.py mode change 100644 => 100755 src/transformers/quantizers/auto.py create mode 100755 src/transformers/quantizers/quantizer_hqq.py mode change 100644 => 100755 src/transformers/utils/__init__.py mode change 100644 => 100755 src/transformers/utils/import_utils.py mode change 100644 => 100755 src/transformers/utils/quantization_config.py create mode 100755 tests/quantization/hqq/test_hqq.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile old mode 100644 new mode 100755 index 08bc3c45b952db..47fcd11fd766d7 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt # Add aqlm for quantization testing RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2 +# Add hqq for quantization testing +RUN python3 -m pip install --no-cache-dir hqq + # Add autoawq for quantization testing # >=v0.2.3 needed for compatibility with torch 2.2.1 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md old mode 100644 new mode 100755 index 91de5fc8a33ce1..f1e2acdcfe4809 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. ## HfQuantizer [[autodoc]] quantizers.base.HfQuantizer + +## HqqConfig + +[[autodoc]] HqqConfig diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md old mode 100644 new mode 100755 index 8a3650a8439040..ae4f44f6b800b7 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -745,3 +745,53 @@ The speed and throughput of fused and unfused modules were also tested with the
generate throughput/batch size
+ +## HQQ +Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model. +Please refer to the official package for more details. + +For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels: +``` +pip install hqq +``` + +To quantize a model, you need to create an [`HqqConfig`]. There are two ways of doing it: +``` Python +from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig + +# Method 1: all linear layers will use the same quantization config +quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default +``` + +``` Python +# Method 2: each linear layer with the same tag will use a dedicated quantization config +q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False} +q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False} +quant_config = HqqConfig(dynamic_config={ + 'self_attn.q_proj':q4_config, + 'self_attn.k_proj':q4_config, + 'self_attn.v_proj':q4_config, + 'self_attn.o_proj':q4_config, + + 'mlp.gate_proj':q3_config, + 'mlp.up_proj' :q3_config, + 'mlp.down_proj':q3_config, +}) +``` + +The second approach is especially interesting for quantizing Mixture-of-Experts (MoEs) because the experts are less affected by lower quantization settings. + + +Then you simply quantize the model as follows +``` Python +model = transformers.AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + device_map="cuda", + quantization_config=quant_config +) +``` +### Optimized Runtime +HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training. +For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090. +For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend \ No newline at end of file diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md old mode 100644 new mode 100755 diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py old mode 100644 new mode 100755 index 53a087468e66a0..12f1821df32f91 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1133,6 +1133,7 @@ "BitsAndBytesConfig", "EetqConfig", "GPTQConfig", + "HqqConfig", "QuantoConfig", ], } @@ -6099,6 +6100,7 @@ BitsAndBytesConfig, EetqConfig, GPTQConfig, + HqqConfig, QuantoConfig, ) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py old mode 100644 new mode 100755 index 72fdf3e1bbb997..69fb0e3259b1d5 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -43,6 +43,7 @@ "unset_hf_deepspeed_config", ], "eetq": ["replace_with_eetq_linear"], + "hqq": ["prepare_for_hqq_linear"], "integration_utils": [ "INTEGRATION_TO_CALLBACK", "AzureMLCallback", @@ -113,6 +114,7 @@ unset_hf_deepspeed_config, ) from .eetq import replace_with_eetq_linear + from .hqq import prepare_for_hqq_linear from .integration_utils import ( INTEGRATION_TO_CALLBACK, AzureMLCallback, diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py new file mode 100755 index 00000000000000..10a6d06a3f9f0b --- /dev/null +++ b/src/transformers/integrations/hqq.py @@ -0,0 +1,121 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"HQQ (Half-Quadratic Quantization) integration file" + +from ..utils import is_hqq_available, is_torch_available, logging + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +# Name all modules inside the model +def autoname_modules(model): + for name, module in model.named_modules(): + module.name = name + + +# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj +def name_to_linear_tag(name): + return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))]) + + +# Get all linear tags available +def get_linear_tags(model): + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + + linear_tags = set() + for name, module in model.named_modules(): + if isinstance(module, (torch.nn.Linear, HQQLinear)): + linear_tags.add(name_to_linear_tag(name)) + return list(linear_tags) + + +def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, torch.nn.Linear): + # Get linear tag + linear_tag = name_to_linear_tag(module.name) + + # We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param() + if linear_tag in patch_params: + if patch_params[linear_tag] is not None: + model._modules[name].quant_config = patch_params[linear_tag] + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + has_been_replaced = True + + if len(list(module.children())) > 0: + _, has_been_replaced = _prepare_for_hqq_linear( + module, + patch_params=patch_params, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + + return model, has_been_replaced + + +def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False): + """ + Prepares nn.Linear layers for HQQ quantization. + Since each layer type can have separate quantization parameters, we need to do the following: + 1- tag each module with its neme via autoname_modules() + 2- Extract linear_tags (e.g. ['self_attn.q_proj', ...]) + 3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear exepects it, this is referred to as patch_params + """ + + modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert + + # Add name to module + autoname_modules(model) + + # Get linear tags. This allows us to use different quant params to different layer types + linear_tags = get_linear_tags(model) + + # Convert quantization_config to layer-wise config + skip_modules = quantization_config.skip_modules + quant_config = quantization_config.to_dict() + linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert)) + + if any(key in linear_tags for key in quant_config.keys()): + # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None) + patch_params = {key: None for key in linear_tags} + patch_params.update(quant_config) + else: + # Same quant_config for all layers + patch_params = {k: quant_config for k in linear_tags} + + model, has_been_replaced = _prepare_for_hqq_linear( + model, patch_params=patch_params, has_been_replaced=has_been_replaced + ) + + # We store quantization config as linear_tag -> hqq quant config + model.config.quantization_config = patch_params + + if not has_been_replaced: + logger.warning("No linear modules were found in your model for quantization.") + + return model diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py old mode 100644 new mode 100755 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100644 new mode 100755 index 4b20b32aa694d2..59b6bf80752053 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2659,6 +2659,8 @@ def get_memory_footprint(self, return_buffers=True): @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.cuda` is not supported for HQQ-quantized models.") # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( @@ -2670,6 +2672,8 @@ def cuda(self, *args, **kwargs): @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.to` is not supported for HQQ-quantized models.") # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( @@ -3739,6 +3743,13 @@ def from_pretrained( } if "skip_keys" in inspect.signature(dispatch_model).parameters: device_map_kwargs["skip_keys"] = model._skip_keys_device_placement + # For HQQ method we force-set the hooks for single GPU envs + if ( + "force_hooks" in inspect.signature(dispatch_model).parameters + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ + ): + device_map_kwargs["force_hooks"] = True if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py old mode 100644 new mode 100755 diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py old mode 100644 new mode 100755 index cc58cd7af69ffb..2c65afa77e282c --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -21,6 +21,7 @@ BitsAndBytesConfig, EetqConfig, GPTQConfig, + HqqConfig, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, @@ -31,6 +32,7 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_gptq import GptqHfQuantizer +from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -42,6 +44,7 @@ "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, "eetq": EetqHfQuantizer, + "hqq": HqqHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -52,6 +55,7 @@ "gptq": GPTQConfig, "aqlm": AqlmConfig, "quanto": QuantoConfig, + "hqq": HqqConfig, } diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py new file mode 100755 index 00000000000000..dd58c2c1bc5a27 --- /dev/null +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -0,0 +1,200 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List + +from ..integrations import prepare_for_hqq_linear +from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging +from .base import HfQuantizer +from .quantizers_utils import get_module_from_name + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + + +if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +# Finds the parent of a node module named "name" +def find_parent(model, name): + module_tree = name.split(".")[:-1] + parent = model + for m in module_tree: + parent = parent._modules[m] + return parent + + +class HqqHfQuantizer(HfQuantizer): + """ + HQQ quantizer base HF class. + nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading(). + The actual quantization and offloading to the GPU is done in check_quantized_param(). + """ + + use_keep_in_fp32_modules = False + requires_parameters_quantization = True + requires_calibration = False + required_packages = ["hqq"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.torch_dtype = None + self.using_multi_gpu = False + + def validate_environment(self, *args, **kwargs): + if not (is_hqq_available()): + raise ImportError( + "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`" + ) + + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): + raise ValueError( + "Converting weights from tf/flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + + if self.torch_dtype is None: + if "torch_dtype" in kwargs: + self.torch_dtype = kwargs["torch_dtype"] + else: + self.torch_dtype = torch.float32 + logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.") + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + raise ValueError( + "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." + ) + else: + self.using_multi_gpu = len(set(device_map.values())) > 1 + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + module, tensor_name = get_module_from_name(model, param_name) + + return isinstance(module, torch.nn.Linear) + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: List[str], + ): + """ + Each nn.Linear layer is processsed here. + We first check if the corresponding module state_dict contains already HQQ quantized parameters. + If not, we create a temp linear layer with the module state_dict params and use it for quantization + """ + + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + + module, tensor_name = get_module_from_name(model, param_name) + + layer_name = param_name.replace(".weight", "").replace(".bias", "") + parent_module = find_parent(model, layer_name) + node = layer_name.split(".")[-1] + + # Step 0: set module state_dict + module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} + + # Step 1: populate module with weight/bias from module state dict + for key in module_state_dict: + setattr(module, key, torch.nn.Parameter(module_state_dict[key])) + + # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module + # directly doesn't work. + + if hasattr(module, "quant_config"): + hqq_layer = HQQLinear( + module, + module.quant_config, + compute_dtype=self.torch_dtype, + device=target_device, + del_orig=True, + ) + + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + + if self.using_multi_gpu: + hqq_layer = self._patch_layer_for_multigpu(hqq_layer) + + setattr(parent_module, node, hqq_layer) + + else: + module = module.to(dtype=self.torch_dtype, device=target_device) + setattr(parent_module, node, module) + + torch.cuda.empty_cache() + + # Remove accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu + def _patch_layer_for_multigpu(self, hqq_layer): + hqq_layer = remove_hook_from_module(hqq_layer) + + def forward_with_device(self, x): + out = torch.matmul(x.to(self.device), self.dequantize().t()) + if self.bias is not None: + out += self.bias + return out + + hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x) + return hqq_layer + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = None, + **kwargs, + ): + keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else [] + + # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param(). + # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config) + model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + model.is_hqq_quantized = True + model.is_hqq_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self) -> bool: + return False diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py old mode 100644 new mode 100755 index e4ff991ed75c74..2bfa5638df9225 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -129,6 +129,7 @@ is_ftfy_available, is_g2p_en_available, is_galore_torch_available, + is_hqq_available, is_in_notebook, is_ipex_available, is_jieba_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py old mode 100644 new mode 100755 index c65d4122b787d4..158896347a7a6b --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -170,6 +170,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _torchdistx_available = _is_package_available("torchdistx") _torchvision_available = _is_package_available("torchvision") _mlx_available = _is_package_available("mlx") +_hqq_available = _is_package_available("hqq") _torch_version = "N/A" @@ -292,6 +293,10 @@ def is_torch_available(): return _torch_available +def is_hqq_available(): + return _hqq_available + + def get_torch_version(): return _torch_version diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py old mode 100644 new mode 100755 index 8374ddef81d583..f9e503cf862f18 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -24,7 +24,7 @@ from packaging import version -from ..utils import is_auto_awq_available, is_torch_available, logging +from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging if is_torch_available(): @@ -41,6 +41,7 @@ class QuantizationMethod(str, Enum): AQLM = "aqlm" QUANTO = "quanto" EETQ = "eetq" + HQQ = "hqq" class AWQLinearVersion(str, Enum): @@ -180,6 +181,115 @@ def update(self, **kwargs): return unused_kwargs +@dataclass +class HqqConfig(QuantizationConfigMixin): + """ + This is wrapper around hqq's BaseQuantizeConfig. + + Args: + nbits (`int`, *optional*, defaults to 4): + Number of bits. Supported values are (8, 4, 3, 2, 1). + group_size (`int`, *optional*, defaults to 64): + Group-size value. Supported values are any value that is divisble by weight.shape[axis]). + quant_zero (`bool`, *optional*, defaults to `True`): + Quantize the zero-point if set to `True`. + quant_scale (`bool`, *optional*, defaults to `False`): + Quantize the scaling if set to `True`. + offload_meta (`bool`, *optional*, defaults to `False`): + Offload the meta-data to the CPU if set to `True`. + view_as_float (`bool`, *optional*, defaults to `False`): + View the quantized weight as float (used in distributed training) if set to `True`. + axis (`int`, *optional*, defaults to 0): + Axis along which grouping is performed. Supported values are 0 or 1. + dynamic_config (dict, *optional*): + Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config. + If set, each layer specified by its id will use its dedicated quantization configuration. + skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): + List of `nn.Linear` layers to skip. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + nbits: int = 4, + group_size: int = 64, + quant_zero: bool = True, + quant_scale: bool = False, + offload_meta: bool = False, + view_as_float: bool = False, + axis: int = 0, + dynamic_config: Optional[dict] = None, + skip_modules: List[str] = ["lm_head"], + **kwargs, + ): + if is_hqq_available(): + from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + + if axis not in [0, 1]: + raise ValueError("Invalid axis value. Only 0 and 1 are allowed.") + + if dynamic_config is not None: + self.quant_config = {} + for key in dynamic_config: + self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key]) + else: + self.quant_config = HQQBaseQuantizeConfig( + **{ + "nbits": nbits, + "group_size": group_size, + "quant_zero": quant_zero, + "quant_scale": quant_scale, + "offload_meta": offload_meta, + "view_as_float": view_as_float, + "axis": axis, + } + ) + + self.quant_method = QuantizationMethod.HQQ + self.skip_modules = skip_modules + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + pass + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return self.quant_config + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = HqqConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + @dataclass class BitsAndBytesConfig(QuantizationConfigMixin): """ diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py new file mode 100755 index 00000000000000..e4e01f86496388 --- /dev/null +++ b/tests/quantization/hqq/test_hqq.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig +from transformers.testing_utils import ( + require_accelerate, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_hqq_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_hqq_available(): + from hqq.core.quantize import HQQBackend, HQQLinear + + +class HQQLLMRunner: + def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir): + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=compute_dtype, + device_map=device, + quantization_config=quant_config, + low_cpu_mem_usage=True, + cache_dir=cache_dir, + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir) + self.device = self.model.device + HQQLinear.set_backend(HQQBackend.PYTORCH) + + +def cleanup(): + torch.cuda.empty_cache() + gc.collect() + + +def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024): + # Test HQQ layer + W_dequant = hqq_layer.dequantize() # Reconstructed weights + inputs = ( + torch.randn( + (batch_size, context_size, hqq_layer.meta["shape"][1]), + device=hqq_layer.device, + dtype=hqq_layer.compute_dtype, + ) + / 10.0 + ) + with torch.no_grad(): + outputs = hqq_layer(inputs) + test_module.assertEqual(outputs.shape[-1], W_dequant.shape[0]) + test_module.assertEqual(outputs.dtype, hqq_layer.compute_dtype) + del W_dequant, inputs, outputs + cleanup() + + +def check_forward(test_module, model, batch_size=1, context_size=1024): + # Test forward pass + with torch.no_grad(): + out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits + test_module.assertEqual(out.shape[0], batch_size) + test_module.assertEqual(out.shape[1], context_size) + cleanup() + + +MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + +@require_torch_gpu +class HqqConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = HqqConfig() + hqq_orig_config = quantization_config.to_dict() + + for key in hqq_orig_config: + self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key]) + + +@slow +@require_torch_gpu +@require_accelerate +class HQQTest(unittest.TestCase): + def tearDown(self): + cleanup() + + def test_fp16_quantized_model(self): + """ + Simple LLM model testing fp16 + """ + quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + + hqq_runner = HQQLLMRunner( + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device + ) + + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model) + + def test_bfp16_quantized_model_with_offloading(self): + """ + Simple LLM model testing bfp16 with meta-data offloading + """ + q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} + q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True} + quant_config = HqqConfig( + dynamic_config={ + "self_attn.q_proj": q4_config, + "self_attn.k_proj": q4_config, + "self_attn.v_proj": q4_config, + "self_attn.o_proj": q4_config, + "mlp.gate_proj": q3_config, + "mlp.up_proj": q3_config, + "mlp.down_proj": q3_config, + } + ) + + hqq_runner = HQQLLMRunner( + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device + ) + + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model) + + +@slow +@require_torch_gpu +@require_torch_multi_gpu +@require_accelerate +class HQQTestMultiGPU(unittest.TestCase): + def tearDown(self): + cleanup() + + def test_fp16_quantized_model_multipgpu(self): + """ + Simple LLM model testing fp16 with multi-gpu + """ + + quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + + hqq_runner = HQQLLMRunner( + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto" + ) + + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model)