From 59952994c422ae9fc17f8a3134c00396c948c42e Mon Sep 17 00:00:00 2001
From: mobicham <37179323+mobicham@users.noreply.github.com>
Date: Thu, 2 May 2024 18:51:49 +0200
Subject: [PATCH] Add HQQ quantization support (#29637)
* update HQQ transformers integration
* push import_utils.py
* add force_hooks check in modeling_utils.py
* fix | with Optional
* force bias as param
* check bias is Tensor
* force forward for multi-gpu
* review fixes pass
* remove torch grad()
* if any key in linear_tags fix
* add cpu/disk check
* isinstance return
* add multigpu test + refactor tests
* clean hqq_utils imports in hqq.py
* clean hqq_utils imports in quantizer_hqq.py
* delete hqq_utils.py
* Delete src/transformers/utils/hqq_utils.py
* ruff init
* remove torch.float16 from __init__ in test
* refactor test
* isinstance -> type in quantizer_hqq.py
* cpu/disk device_map check in quantizer_hqq.py
* remove type(module) nn.linear check in quantizer_hqq.py
* add BaseQuantizeConfig import inside HqqConfig init
* remove hqq import in hqq.py
* remove accelerate import from test_hqq.py
* quant config.py doc update
* add hqqconfig to main_classes doc
* make style
* __init__ fix
* ruff __init__
* skip_modules list
* hqqconfig format fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* test_hqq.py remove mistral comment
* remove self.using_multi_gpu is False
* torch_dtype default val set and logger.info
* hqq.py isinstance fix
* remove torch=None
* torch_device test_hqq
* rename test_hqq
* MODEL_ID in test_hqq
* quantizer_hqq setattr fix
* quantizer_hqq typo fix
* imports quantizer_hqq.py
* isinstance quantizer_hqq
* hqq_layer.bias reformat quantizer_hqq
* Step 2 as comment in quantizer_hqq
* prepare_for_hqq_linear() comment
* keep_in_fp32_modules fix
* HqqHfQuantizer reformat
* quantization.md hqqconfig
* quantization.md model example reformat
* quantization.md # space
* quantization.md space })
* quantization.md space })
* quantization_config fix doc
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* axis value check in quantization_config
* format
* dynamic config explanation
* quant config method in quantization.md
* remove shard-level progress
* .cuda fix modeling_utils
* test_hqq fixes
* make fix-copies
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
.../Dockerfile | 3 +
docs/source/en/main_classes/quantization.md | 4 +
docs/source/en/quantization.md | 50 +++++
docs/source/en/quicktour.md | 0
src/transformers/__init__.py | 2 +
src/transformers/integrations/__init__.py | 2 +
src/transformers/integrations/hqq.py | 121 +++++++++++
.../integrations/integration_utils.py | 0
src/transformers/modeling_utils.py | 11 +
src/transformers/quantizers/__init__.py | 0
src/transformers/quantizers/auto.py | 4 +
src/transformers/quantizers/quantizer_hqq.py | 200 ++++++++++++++++++
src/transformers/utils/__init__.py | 1 +
src/transformers/utils/import_utils.py | 5 +
src/transformers/utils/quantization_config.py | 112 +++++++++-
tests/quantization/hqq/test_hqq.py | 167 +++++++++++++++
16 files changed, 681 insertions(+), 1 deletion(-)
mode change 100644 => 100755 docker/transformers-quantization-latest-gpu/Dockerfile
mode change 100644 => 100755 docs/source/en/main_classes/quantization.md
mode change 100644 => 100755 docs/source/en/quantization.md
mode change 100644 => 100755 docs/source/en/quicktour.md
mode change 100644 => 100755 src/transformers/__init__.py
mode change 100644 => 100755 src/transformers/integrations/__init__.py
create mode 100755 src/transformers/integrations/hqq.py
mode change 100644 => 100755 src/transformers/integrations/integration_utils.py
mode change 100644 => 100755 src/transformers/modeling_utils.py
mode change 100644 => 100755 src/transformers/quantizers/__init__.py
mode change 100644 => 100755 src/transformers/quantizers/auto.py
create mode 100755 src/transformers/quantizers/quantizer_hqq.py
mode change 100644 => 100755 src/transformers/utils/__init__.py
mode change 100644 => 100755 src/transformers/utils/import_utils.py
mode change 100644 => 100755 src/transformers/utils/quantization_config.py
create mode 100755 tests/quantization/hqq/test_hqq.py
diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile
old mode 100644
new mode 100755
index 08bc3c45b952db..47fcd11fd766d7
--- a/docker/transformers-quantization-latest-gpu/Dockerfile
+++ b/docker/transformers-quantization-latest-gpu/Dockerfile
@@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
+# Add hqq for quantization testing
+RUN python3 -m pip install --no-cache-dir hqq
+
# Add autoawq for quantization testing
# >=v0.2.3 needed for compatibility with torch 2.2.1
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
old mode 100644
new mode 100755
index 91de5fc8a33ce1..f1e2acdcfe4809
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HfQuantizer
[[autodoc]] quantizers.base.HfQuantizer
+
+## HqqConfig
+
+[[autodoc]] HqqConfig
diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md
old mode 100644
new mode 100755
index 8a3650a8439040..ae4f44f6b800b7
--- a/docs/source/en/quantization.md
+++ b/docs/source/en/quantization.md
@@ -745,3 +745,53 @@ The speed and throughput of fused and unfused modules were also tested with the
generate throughput/batch size
+
+## HQQ
+Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model.
+Please refer to the [official package](https://github.com/mobiusml/hqq/) for more details.
+
+For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels:
+```bash
+pip install hqq
+```
+
+To quantize a model, you need to create an [`HqqConfig`]. There are two ways of doing it:
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
+
+# Method 1: all linear layers will use the same quantization config
+quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)  # axis=0 is used by default
+```
+
+```python
+# Method 2: each linear layer with the same tag will use a dedicated quantization config
+q4_config = {'nbits': 4, 'group_size': 64, 'quant_zero': False, 'quant_scale': False}
+q3_config = {'nbits': 3, 'group_size': 32, 'quant_zero': False, 'quant_scale': False}
+quant_config = HqqConfig(dynamic_config={
+  'self_attn.q_proj': q4_config,
+  'self_attn.k_proj': q4_config,
+  'self_attn.v_proj': q4_config,
+  'self_attn.o_proj': q4_config,
+
+  'mlp.gate_proj': q3_config,
+  'mlp.up_proj': q3_config,
+  'mlp.down_proj': q3_config,
+})
+```
+
+The second approach is particularly useful for quantizing Mixture-of-Experts (MoE) models, because the experts are less affected by lower quantization settings.
+
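+As a rough sketch (the expert projection names below are hypothetical and depend on the specific MoE architecture, e.g. Mixtral-style experts), you could give the expert layers a lower bit-width than the attention layers:
+```python
+q4_config = {'nbits': 4, 'group_size': 64, 'quant_zero': False, 'quant_scale': False}
+q2_config = {'nbits': 2, 'group_size': 16, 'quant_zero': False, 'quant_scale': False}
+
+quant_config = HqqConfig(dynamic_config={
+  # Attention projections: higher precision
+  'self_attn.q_proj': q4_config,
+  'self_attn.k_proj': q4_config,
+  'self_attn.v_proj': q4_config,
+  'self_attn.o_proj': q4_config,
+
+  # Expert projections: lower precision (the name tags depend on the model)
+  'block_sparse_moe.experts.w1': q2_config,
+  'block_sparse_moe.experts.w2': q2_config,
+  'block_sparse_moe.experts.w3': q2_config,
+})
+```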
+
+Then you simply quantize the model as follows:
+```python
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.float16,
+ device_map="cuda",
+ quantization_config=quant_config
+)
+```
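+
+Once loaded, the quantized model can be used like any other `transformers` model. As a brief, illustrative generation sketch (the prompt is arbitrary):
+```python
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+inputs = tokenizer("Explain Half-Quadratic Quantization in one sentence.", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=64)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+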
+### Optimized Runtime
+HQQ supports various backends, including pure PyTorch and custom dequantization CUDA kernels. These backends are suitable for older GPUs and PEFT/QLoRA training.
+For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
+For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
\ No newline at end of file
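+
+As a minimal sketch of backend selection (mirroring what the HQQ test suite does; the set of available backends depends on the installed `hqq` version):
+```python
+from hqq.core.quantize import HQQBackend, HQQLinear
+
+# Switch the dequantization backend used by all HQQLinear layers (here: the pure PyTorch backend)
+HQQLinear.set_backend(HQQBackend.PYTORCH)
+```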
diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md
old mode 100644
new mode 100755
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
old mode 100644
new mode 100755
index 53a087468e66a0..12f1821df32f91
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1133,6 +1133,7 @@
"BitsAndBytesConfig",
"EetqConfig",
"GPTQConfig",
+ "HqqConfig",
"QuantoConfig",
],
}
@@ -6099,6 +6100,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
+ HqqConfig,
QuantoConfig,
)
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
old mode 100644
new mode 100755
index 72fdf3e1bbb997..69fb0e3259b1d5
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -43,6 +43,7 @@
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
+ "hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
@@ -113,6 +114,7 @@
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
+ from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,
diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py
new file mode 100755
index 00000000000000..10a6d06a3f9f0b
--- /dev/null
+++ b/src/transformers/integrations/hqq.py
@@ -0,0 +1,121 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"HQQ (Half-Quadratic Quantization) integration file"
+
+from ..utils import is_hqq_available, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+# Name all modules inside the model
+def autoname_modules(model):
+ for name, module in model.named_modules():
+ module.name = name
+
+
+# Get the linear_tag from a module name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
+def name_to_linear_tag(name):
+ return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])
+
+
+# Get all linear tags available
+def get_linear_tags(model):
+ if is_hqq_available():
+ from hqq.core.quantize import HQQLinear
+
+ linear_tags = set()
+ for name, module in model.named_modules():
+ if isinstance(module, (torch.nn.Linear, HQQLinear)):
+ linear_tags.add(name_to_linear_tag(name))
+ return list(linear_tags)
+
+
+def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, torch.nn.Linear):
+ # Get linear tag
+ linear_tag = name_to_linear_tag(module.name)
+
+ # We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
+ if linear_tag in patch_params:
+ if patch_params[linear_tag] is not None:
+ model._modules[name].quant_config = patch_params[linear_tag]
+ # Store the module class in case we need to transpose the weight later
+ model._modules[name].source_cls = type(module)
+ # Force requires grad to False to avoid unexpected errors
+ model._modules[name].requires_grad_(False)
+
+ has_been_replaced = True
+
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _prepare_for_hqq_linear(
+ module,
+ patch_params=patch_params,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+
+ return model, has_been_replaced
+
+
+def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
+ """
+ Prepares nn.Linear layers for HQQ quantization.
+ Since each layer type can have separate quantization parameters, we need to do the following:
+    1- Tag each module with its name via autoname_modules()
+    2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
+    3- Map quantization parameters as a dictionary linear_tag -> quant_params, as HQQLinear expects it; this is referred to as patch_params
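+
+    For illustration, the resulting patch_params maps each linear tag to its quant config, or to None for layers that
+    are left unquantized, e.g. {'self_attn.q_proj': {'nbits': 4, 'group_size': 64}, 'mlp.down_proj': None}.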
+ """
+
+ modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
+
+ # Add name to module
+ autoname_modules(model)
+
+    # Get linear tags. This allows us to use different quant params for different layer types
+ linear_tags = get_linear_tags(model)
+
+ # Convert quantization_config to layer-wise config
+ skip_modules = quantization_config.skip_modules
+ quant_config = quantization_config.to_dict()
+ linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
+
+ if any(key in linear_tags for key in quant_config.keys()):
+        # Linear tags the user didn't specify a config for are mapped to None, i.e. the corresponding layers are not quantized
+ patch_params = {key: None for key in linear_tags}
+ patch_params.update(quant_config)
+ else:
+ # Same quant_config for all layers
+ patch_params = {k: quant_config for k in linear_tags}
+
+ model, has_been_replaced = _prepare_for_hqq_linear(
+ model, patch_params=patch_params, has_been_replaced=has_been_replaced
+ )
+
+ # We store quantization config as linear_tag -> hqq quant config
+ model.config.quantization_config = patch_params
+
+ if not has_been_replaced:
+ logger.warning("No linear modules were found in your model for quantization.")
+
+ return model
diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
old mode 100644
new mode 100755
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
old mode 100644
new mode 100755
index 4b20b32aa694d2..59b6bf80752053
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2659,6 +2659,8 @@ def get_memory_footprint(self, return_buffers=True):
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
+ if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+ raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
@@ -2670,6 +2672,8 @@ def cuda(self, *args, **kwargs):
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
+ if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+ raise ValueError("`.to` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
@@ -3739,6 +3743,13 @@ def from_pretrained(
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+ # For HQQ method we force-set the hooks for single GPU envs
+ if (
+ "force_hooks" in inspect.signature(dispatch_model).parameters
+ and hf_quantizer is not None
+ and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
+ ):
+ device_map_kwargs["force_hooks"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py
old mode 100644
new mode 100755
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
old mode 100644
new mode 100755
index cc58cd7af69ffb..2c65afa77e282c
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -21,6 +21,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
+ HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -31,6 +32,7 @@
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
+from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
@@ -42,6 +44,7 @@
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
+ "hqq": HqqHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -52,6 +55,7 @@
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
+ "hqq": HqqConfig,
}
diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py
new file mode 100755
index 00000000000000..dd58c2c1bc5a27
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_hqq.py
@@ -0,0 +1,200 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from ..integrations import prepare_for_hqq_linear
+from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging
+from .base import HfQuantizer
+from .quantizers_utils import get_module_from_name
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+
+if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+# Finds the parent of a node module named "name"
+def find_parent(model, name):
+ module_tree = name.split(".")[:-1]
+ parent = model
+ for m in module_tree:
+ parent = parent._modules[m]
+ return parent
+
+
+class HqqHfQuantizer(HfQuantizer):
+ """
+ HQQ quantizer base HF class.
+ nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
+    The actual quantization and offloading to the GPU is done in create_quantized_param().
+ """
+
+ use_keep_in_fp32_modules = False
+ requires_parameters_quantization = True
+ requires_calibration = False
+ required_packages = ["hqq"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+ self.torch_dtype = None
+ self.using_multi_gpu = False
+
+ def validate_environment(self, *args, **kwargs):
+ if not (is_hqq_available()):
+ raise ImportError(
+ "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
+ )
+
+ if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting weights from tf/flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+
+ if self.torch_dtype is None:
+ if "torch_dtype" in kwargs:
+ self.torch_dtype = kwargs["torch_dtype"]
+ else:
+ self.torch_dtype = torch.float32
+ logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.")
+
+ device_map = kwargs.get("device_map", None)
+ if isinstance(device_map, dict):
+ if "cpu" in device_map.values() or "disk" in device_map.values():
+ raise ValueError(
+ "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device."
+ " This is not supported. Please remove the CPU or disk device from the device_map."
+ )
+ else:
+ self.using_multi_gpu = len(set(device_map.values())) > 1
+
+ def check_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ return isinstance(module, torch.nn.Linear)
+
+ def create_quantized_param(
+ self,
+ model: "PreTrainedModel",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: List[str],
+ ):
+ """
+        Each nn.Linear layer is processed here.
+        We first check if the corresponding module's state_dict already contains HQQ-quantized parameters.
+        If not, we create a temporary linear layer from the module state_dict params and use it for quantization.
+ """
+
+ if is_hqq_available():
+ from hqq.core.quantize import HQQLinear
+
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ layer_name = param_name.replace(".weight", "").replace(".bias", "")
+ parent_module = find_parent(model, layer_name)
+ node = layer_name.split(".")[-1]
+
+ # Step 0: set module state_dict
+ module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
+
+ # Step 1: populate module with weight/bias from module state dict
+ for key in module_state_dict:
+ setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
+
+        # Step 2: Replace the module with either HQQLinear or move it to the target device. We do this via setattr
+        # on the parent module, as doing it on the module directly doesn't work.
+
+ if hasattr(module, "quant_config"):
+ hqq_layer = HQQLinear(
+ module,
+ module.quant_config,
+ compute_dtype=self.torch_dtype,
+ device=target_device,
+ del_orig=True,
+ )
+
+ if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+ hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+ if self.using_multi_gpu:
+ hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+ setattr(parent_module, node, hqq_layer)
+
+ else:
+ module = module.to(dtype=self.torch_dtype, device=target_device)
+ setattr(parent_module, node, module)
+
+ torch.cuda.empty_cache()
+
+    # Removes the accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu setups
+ def _patch_layer_for_multigpu(self, hqq_layer):
+ hqq_layer = remove_hook_from_module(hqq_layer)
+
+ def forward_with_device(self, x):
+ out = torch.matmul(x.to(self.device), self.dequantize().t())
+ if self.bias is not None:
+ out += self.bias
+ return out
+
+ hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x)
+ return hqq_layer
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "PreTrainedModel",
+ device_map,
+ keep_in_fp32_modules: List[str] = None,
+ **kwargs,
+ ):
+ keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
+
+ # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
+ # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
+ model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ model.is_hqq_quantized = True
+ model.is_hqq_serializable = self.is_serializable
+ return model
+
+ @property
+ def is_serializable(self):
+ return False
+
+ @property
+ def is_trainable(self) -> bool:
+ return False
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
old mode 100644
new mode 100755
index e4ff991ed75c74..2bfa5638df9225
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -129,6 +129,7 @@
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
+ is_hqq_available,
is_in_notebook,
is_ipex_available,
is_jieba_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
old mode 100644
new mode 100755
index c65d4122b787d4..158896347a7a6b
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -170,6 +170,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
+_hqq_available = _is_package_available("hqq")
_torch_version = "N/A"
@@ -292,6 +293,10 @@ def is_torch_available():
return _torch_available
+def is_hqq_available():
+ return _hqq_available
+
+
def get_torch_version():
return _torch_version
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
old mode 100644
new mode 100755
index 8374ddef81d583..f9e503cf862f18
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -24,7 +24,7 @@
from packaging import version
-from ..utils import is_auto_awq_available, is_torch_available, logging
+from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
if is_torch_available():
@@ -41,6 +41,7 @@ class QuantizationMethod(str, Enum):
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
+ HQQ = "hqq"
class AWQLinearVersion(str, Enum):
@@ -180,6 +181,115 @@ def update(self, **kwargs):
return unused_kwargs
+@dataclass
+class HqqConfig(QuantizationConfigMixin):
+ """
+    This is a wrapper around hqq's BaseQuantizeConfig.
+
+ Args:
+ nbits (`int`, *optional*, defaults to 4):
+ Number of bits. Supported values are (8, 4, 3, 2, 1).
+ group_size (`int`, *optional*, defaults to 64):
+            Group-size value. Supported values are any value that is divisible by weight.shape[axis].
+ quant_zero (`bool`, *optional*, defaults to `True`):
+ Quantize the zero-point if set to `True`.
+ quant_scale (`bool`, *optional*, defaults to `False`):
+            Quantize the scaling factor if set to `True`.
+ offload_meta (`bool`, *optional*, defaults to `False`):
+ Offload the meta-data to the CPU if set to `True`.
+ view_as_float (`bool`, *optional*, defaults to `False`):
+ View the quantized weight as float (used in distributed training) if set to `True`.
+ axis (`int`, *optional*, defaults to 0):
+ Axis along which grouping is performed. Supported values are 0 or 1.
+        dynamic_config (`dict`, *optional*):
+            Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
+            If set, each layer specified by its name tag will use its dedicated quantization configuration.
+ skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`):
+ List of `nn.Linear` layers to skip.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters from which to initialize the configuration object.
+ """
+
+ def __init__(
+ self,
+ nbits: int = 4,
+ group_size: int = 64,
+ quant_zero: bool = True,
+ quant_scale: bool = False,
+ offload_meta: bool = False,
+ view_as_float: bool = False,
+ axis: int = 0,
+ dynamic_config: Optional[dict] = None,
+ skip_modules: List[str] = ["lm_head"],
+ **kwargs,
+ ):
+ if is_hqq_available():
+ from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
+
+ if axis not in [0, 1]:
+ raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
+
+ if dynamic_config is not None:
+ self.quant_config = {}
+ for key in dynamic_config:
+ self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
+ else:
+ self.quant_config = HQQBaseQuantizeConfig(
+ **{
+ "nbits": nbits,
+ "group_size": group_size,
+ "quant_zero": quant_zero,
+ "quant_scale": quant_scale,
+ "offload_meta": offload_meta,
+ "view_as_float": view_as_float,
+ "axis": axis,
+ }
+ )
+
+ self.quant_method = QuantizationMethod.HQQ
+ self.skip_modules = skip_modules
+
+ self.post_init()
+
+ def post_init(self):
+ r"""
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
+ """
+ pass
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ return self.quant_config
+
+ def __repr__(self):
+ config_dict = self.to_dict()
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+
+ def to_diff_dict(self) -> Dict[str, Any]:
+ """
+ Removes all attributes from config which correspond to the default config attributes for better readability and
+ serializes to a Python dictionary.
+ Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = HqqConfig().to_dict()
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if value != default_config_dict[key]:
+ serializable_config_dict[key] = value
+
+ return serializable_config_dict
+
+
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
"""
diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py
new file mode 100755
index 00000000000000..e4e01f86496388
--- /dev/null
+++ b/tests/quantization/hqq/test_hqq.py
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
+from transformers.testing_utils import (
+ require_accelerate,
+ require_torch_gpu,
+ require_torch_multi_gpu,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_hqq_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_hqq_available():
+ from hqq.core.quantize import HQQBackend, HQQLinear
+
+
+class HQQLLMRunner:
+ def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=compute_dtype,
+ device_map=device,
+ quantization_config=quant_config,
+ low_cpu_mem_usage=True,
+ cache_dir=cache_dir,
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
+ self.device = self.model.device
+ HQQLinear.set_backend(HQQBackend.PYTORCH)
+
+
+def cleanup():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024):
+ # Test HQQ layer
+ W_dequant = hqq_layer.dequantize() # Reconstructed weights
+ inputs = (
+ torch.randn(
+ (batch_size, context_size, hqq_layer.meta["shape"][1]),
+ device=hqq_layer.device,
+ dtype=hqq_layer.compute_dtype,
+ )
+ / 10.0
+ )
+ with torch.no_grad():
+ outputs = hqq_layer(inputs)
+ test_module.assertEqual(outputs.shape[-1], W_dequant.shape[0])
+ test_module.assertEqual(outputs.dtype, hqq_layer.compute_dtype)
+ del W_dequant, inputs, outputs
+ cleanup()
+
+
+def check_forward(test_module, model, batch_size=1, context_size=1024):
+ # Test forward pass
+ with torch.no_grad():
+ out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits
+ test_module.assertEqual(out.shape[0], batch_size)
+ test_module.assertEqual(out.shape[1], context_size)
+ cleanup()
+
+
+MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+
+@require_torch_gpu
+class HqqConfigTest(unittest.TestCase):
+ def test_to_dict(self):
+ """
+ Makes sure the config format is properly set
+ """
+ quantization_config = HqqConfig()
+ hqq_orig_config = quantization_config.to_dict()
+
+ for key in hqq_orig_config:
+ self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
+
+
+@slow
+@require_torch_gpu
+@require_accelerate
+class HQQTest(unittest.TestCase):
+ def tearDown(self):
+ cleanup()
+
+ def test_fp16_quantized_model(self):
+ """
+ Simple LLM model testing fp16
+ """
+ quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
+
+ hqq_runner = HQQLLMRunner(
+ model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+ )
+
+ check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+ check_forward(self, hqq_runner.model)
+
+ def test_bfp16_quantized_model_with_offloading(self):
+ """
+        Simple LLM model testing bfloat16 with meta-data offloading
+ """
+ q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
+ q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
+ quant_config = HqqConfig(
+ dynamic_config={
+ "self_attn.q_proj": q4_config,
+ "self_attn.k_proj": q4_config,
+ "self_attn.v_proj": q4_config,
+ "self_attn.o_proj": q4_config,
+ "mlp.gate_proj": q3_config,
+ "mlp.up_proj": q3_config,
+ "mlp.down_proj": q3_config,
+ }
+ )
+
+ hqq_runner = HQQLLMRunner(
+ model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
+ )
+
+ check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+ check_forward(self, hqq_runner.model)
+
+
+@slow
+@require_torch_gpu
+@require_torch_multi_gpu
+@require_accelerate
+class HQQTestMultiGPU(unittest.TestCase):
+ def tearDown(self):
+ cleanup()
+
+    def test_fp16_quantized_model_multigpu(self):
+ """
+ Simple LLM model testing fp16 with multi-gpu
+ """
+
+ quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
+
+ hqq_runner = HQQLLMRunner(
+ model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
+ )
+
+ check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+ check_forward(self, hqq_runner.model)