Skip to content

Commit

Permalink
Add HQQ quantization support (#29637)
Browse files Browse the repository at this point in the history
* update HQQ transformers integration

* push import_utils.py

* add force_hooks check in modeling_utils.py

* fix | with Optional

* force bias as param

* check bias is Tensor

* force forward for multi-gpu

* review fixes pass

* remove torch grad()

* if any key in linear_tags fix

* add cpu/disk check

* isinstance return

* add multigpu test + refactor tests

* clean hqq_utils imports in hqq.py

* clean hqq_utils imports in quantizer_hqq.py

* delete hqq_utils.py

* Delete src/transformers/utils/hqq_utils.py

* ruff init

* remove torch.float16 from __init__ in test

* refactor test

* isinstance -> type in quantizer_hqq.py

* cpu/disk device_map check in quantizer_hqq.py

* remove type(module) nn.linear check in quantizer_hqq.py

* add BaseQuantizeConfig import inside HqqConfig init

* remove hqq import in hqq.py

* remove accelerate import from test_hqq.py

* quant config.py doc update

* add hqqconfig to main_classes doc

* make style

* __init__ fix

* ruff __init__

* skip_modules list

* hqqconfig format fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* test_hqq.py remove mistral comment

* remove self.using_multi_gpu is False

* torch_dtype default val set and logger.info

* hqq.py isinstance fix

* remove torch=None

* torch_device test_hqq

* rename test_hqq

* MODEL_ID in test_hqq

* quantizer_hqq setattr fix

* quantizer_hqq typo fix

* imports quantizer_hqq.py

* isinstance quantizer_hqq

* hqq_layer.bias reformat quantizer_hqq

* Step 2 as comment in quantizer_hqq

* prepare_for_hqq_linear() comment

* keep_in_fp32_modules fix

* HqqHfQuantizer reformat

* quantization.md hqqconfig

* quantization.md model example reformat

* quantization.md # space

* quantization.md space   })

* quantization.md space   })

* quantization_config fix doc

Co-authored-by: amyeroberts <[email protected]>

* axis value check in quantization_config

* format

* dynamic config explanation

* quant config method in quantization.md

* remove shard-level progress

* .cuda fix modeling_utils

* test_hqq fixes

* make fix-copies

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
mobicham and amyeroberts authored May 2, 2024
1 parent 4c94093 commit 5995299
Show file tree
Hide file tree
Showing 16 changed files with 681 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docker/transformers-quantization-latest-gpu/Dockerfile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2

# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq

# Add autoawq for quantization testing
# >=v0.2.3 needed for compatibility with torch 2.2.1
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl
Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HfQuantizer

[[autodoc]] quantizers.base.HfQuantizer

## HqqConfig

[[autodoc]] HqqConfig
50 changes: 50 additions & 0 deletions docs/source/en/quantization.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,53 @@ The speed and throughput of fused and unfused modules were also tested with the
<figcaption class="mt-2 text-center text-sm text-gray-500">generate throughput/batch size</figcaption>
</div>
</div>

## HQQ
Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model.
Please refer to the <a href="https://github.com/mobiusml/hqq/">official package</a> for more details.

For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels:
```
pip install hqq
```

To quantize a model, you need to create an [`HqqConfig`]. There are two ways of doing it:
``` Python
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
```

``` Python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,
'self_attn.v_proj':q4_config,
'self_attn.o_proj':q4_config,

'mlp.gate_proj':q3_config,
'mlp.up_proj' :q3_config,
'mlp.down_proj':q3_config,
})
```

The second approach is especially interesting for quantizing Mixture-of-Experts (MoEs) because the experts are less affected by lower quantization settings.


Then you simply quantize the model as follows
``` Python
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda",
quantization_config=quant_config
)
```
### Optimized Runtime
HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training.
For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
Empty file modified docs/source/en/quicktour.md
100644 → 100755
Empty file.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,7 @@
"BitsAndBytesConfig",
"EetqConfig",
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
],
}
Expand Down Expand Up @@ -6099,6 +6100,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantoConfig,
)

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
Expand Down Expand Up @@ -113,6 +114,7 @@
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,
Expand Down
121 changes: 121 additions & 0 deletions src/transformers/integrations/hqq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HQQ (Half-Quadratic Quantization) integration file"

from ..utils import is_hqq_available, is_torch_available, logging


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


# Name all modules inside the model
def autoname_modules(model):
for name, module in model.named_modules():
module.name = name


# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
def name_to_linear_tag(name):
return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])


# Get all linear tags available
def get_linear_tags(model):
if is_hqq_available():
from hqq.core.quantize import HQQLinear

linear_tags = set()
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, HQQLinear)):
linear_tags.add(name_to_linear_tag(name))
return list(linear_tags)


def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, torch.nn.Linear):
# Get linear tag
linear_tag = name_to_linear_tag(module.name)

# We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
if linear_tag in patch_params:
if patch_params[linear_tag] is not None:
model._modules[name].quant_config = patch_params[linear_tag]
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)

has_been_replaced = True

if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
patch_params=patch_params,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)

return model, has_been_replaced


def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
"""
Prepares nn.Linear layers for HQQ quantization.
Since each layer type can have separate quantization parameters, we need to do the following:
1- tag each module with its neme via autoname_modules()
2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear exepects it, this is referred to as patch_params
"""

modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

# Add name to module
autoname_modules(model)

# Get linear tags. This allows us to use different quant params to different layer types
linear_tags = get_linear_tags(model)

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
# If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None)
patch_params = {key: None for key in linear_tags}
patch_params.update(quant_config)
else:
# Same quant_config for all layers
patch_params = {k: quant_config for k in linear_tags}

model, has_been_replaced = _prepare_for_hqq_linear(
model, patch_params=patch_params, has_been_replaced=has_been_replaced
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")

return model
Empty file modified src/transformers/integrations/integration_utils.py
100644 → 100755
Empty file.
11 changes: 11 additions & 0 deletions src/transformers/modeling_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2659,6 +2659,8 @@ def get_memory_footprint(self, return_buffers=True):

@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
Expand All @@ -2670,6 +2672,8 @@ def cuda(self, *args, **kwargs):

@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
Expand Down Expand Up @@ -3739,6 +3743,13 @@ def from_pretrained(
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
# For HQQ method we force-set the hooks for single GPU envs
if (
"force_hooks" in inspect.signature(dispatch_model).parameters
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)

Expand Down
Empty file modified src/transformers/quantizers/__init__.py
100644 → 100755
Empty file.
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
Expand All @@ -31,6 +32,7 @@
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer


Expand All @@ -42,6 +44,7 @@
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -52,6 +55,7 @@
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
}


Expand Down
Loading

0 comments on commit 5995299

Please sign in to comment.