Add save/load support for HQQ (#1913)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: chen, suyue <[email protected]>
yiliu30 and chensuyue authored Jul 15, 2024
1 parent d320460 commit 34f0a9f
Showing 6 changed files with 188 additions and 3 deletions.
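In short, an HQQ-quantized model now exposes a save() method and can be restored through the common weight-only load path. A minimal usage sketch, based on the new test_hqq_load_save test in this commit (the tiny OPT checkpoint is the one the test uses; the save path is illustrative, and prepare, convert, and get_default_hqq_config are assumed to come from neural_compressor.torch.quantization, as the rest of the test file does):

    from copy import deepcopy

    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.quantization import convert, get_default_hqq_config, load, prepare

    fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
    quant_config = get_default_hqq_config()
    qmodel = convert(prepare(deepcopy(fp32_model), quant_config))  # quantize with HQQ
    qmodel.save("hqq_model.pth")  # save() is attached to the quantized model by hqq_entry
    loaded_model = load("hqq_model.pth", deepcopy(fp32_model))  # rebuilds HQQLinear modules, then loads the weights
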
60 changes: 59 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/hqq/core.py
@@ -19,7 +19,7 @@
# NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle`
# and `QTensor` to decouple the data structure and the quantization logic.

from typing import Any, Dict, Tuple
from typing import Any, Dict, Mapping, Tuple

import torch

@@ -278,3 +278,61 @@ def from_float(
        # !!! Delete the float explicitly to save memory
        del float_module
        return new_mod

    def state_dict(self, *args, **kwargs):  # nn.Module override compatible
        state_dict = self.q_weight.to_state_dict()
        if self.bias is not None:
            state_dict["bias"] = self.bias
        if "destination" in kwargs and "prefix" in kwargs:
            for key, value in state_dict.items():
                kwargs["destination"][kwargs["prefix"] + key] = value
        return state_dict

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        all_expected_keys = ["val", "scale_quantized", "zero_quantized", "meta_info"]
        if self.bias is not None:
            all_expected_keys.append("bias")

        for key in all_expected_keys:
            if prefix + key not in state_dict:
                missing_keys.append(key)
        if missing_keys:
            return  # Can't load weights if either weight or meta is missing

        cur_state_dict = {}
        for key in all_expected_keys:
            cur_state_dict[key] = state_dict.pop(prefix + key)

        unexpected_keys += state_dict.keys()
        self._assign_state_dict(cur_state_dict, strict)

    def _assign_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        _scale_quantized = state_dict["scale_quantized"]
        _zero_quantized = state_dict["zero_quantized"]
        scale_state = state_dict["meta_info"]["scale"]
        zero_state = state_dict["meta_info"]["zero"]
        if _scale_quantized:
            scale = HQQTensorHandle._create_q_tensor(scale_state["val"], scale_state["meta_info"])
        else:
            scale = state_dict["meta_info"]["scale"]
        if _zero_quantized:
            zero = HQQTensorHandle._create_q_tensor(zero_state["val"], zero_state["meta_info"])
        else:
            zero = state_dict["meta_info"]["zero"]
        meta = state_dict["meta_info"]
        meta["scale"] = scale
        meta["zero"] = zero
        self.q_weight = HQQTensorHandle._create_q_tensor(state_dict["val"], meta)
        if self.bias is not None:
            self.bias = state_dict["bias"]
        self.quantized = True
        return self
16 changes: 16 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py
@@ -115,3 +115,19 @@ def half(self):
        if self.zero is not None:
            self.zero = self.zero.half()
        return self

    def to_state_dict(self):
        state = {}
        state["val"] = self.val
        state["meta_info"] = self.meta_info.to_dict()
        state["scale_quantized"] = self.is_scale_quantized()
        state["zero_quantized"] = self.is_zero_quantized()
        if self.is_scale_quantized():
            state["meta_info"]["scale"] = self.scale.to_state_dict()
        else:
            state["meta_info"]["scale"] = self.scale
        if self.is_zero_quantized():
            state["meta_info"]["zero"] = self.zero.to_state_dict()
        else:
            state["meta_info"]["zero"] = self.zero
        return state
29 changes: 28 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -124,7 +124,6 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path):

        with open(qconfig_file_path, "r") as file:
            self.quantization_config = json.load(file)

        model = self._build_woq_model()
        model.load_state_dict(qweights, assign=True)
        model.eval()
@@ -157,8 +156,19 @@ def load_hf_format_woq_model(self):

        return model

    def _is_hqq_model(self):
        for name, module in self.original_model.named_modules():
            pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))"
            for q_config_key, q_config_value in self.quantization_config.items():
                if re.search(pattern, q_config_key):
                    if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "hqq":
                        return True

    def _build_woq_model(self):
        """Build weight-only quantization model."""
        if self._is_hqq_model():
            return self._build_hqq_model()

        from neural_compressor.torch.utils import set_module

        from .modules import MulLinear
@@ -228,6 +238,23 @@ def _build_woq_model(self):
        woq_model = self.original_model
        return woq_model

    def _build_hqq_model(self):
        """Replace quantized Linear with HQQLinear."""
        from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
        from neural_compressor.torch.utils import set_module

        for name, module in self.original_model.named_modules():
            if isinstance(module, torch.nn.Linear):
                loaded_state_dict_keys_set = set(self.loaded_state_dict_keys)
                if name + ".val" not in loaded_state_dict_keys_set:
                    continue
                new_module = HQQLinear(
                    in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None
                )
                set_module(self.original_model, name, new_module)
        woq_model = self.original_model
        return woq_model

    def _get_model_class_and_config(self):
        from transformers import AutoConfig, AutoModelForCausalLM
        from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -517,11 +517,14 @@ def hqq_entry(
    **kwargs,
) -> torch.nn.Module:
    from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
    from neural_compressor.torch.algorithms.weight_only.save_load import save

    logger.info("Quantize model with the HQQ algorithm.")

    quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
    model = quantizer.execute(model, mode=mode)
    model.qconfig = configs_mapping
    model.save = MethodType(save, model)
    postprocess_model(model, mode, quantizer)
    dump_model_op_stats(mode, configs_mapping)

5 changes: 4 additions & 1 deletion neural_compressor/torch/quantization/load_entry.py
@@ -22,6 +22,7 @@
    AWQConfig,
    FP8Config,
    GPTQConfig,
    HQQConfig,
    RTNConfig,
    TEQConfig,
)
@@ -89,7 +90,9 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
    # select load function
    config_object = config_mapping[next(iter(config_mapping))]

    if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
    if isinstance(
        config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig, HQQConfig)
    ):  # WOQ
        from neural_compressor.torch.algorithms import weight_only

        return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT)
78 changes: 78 additions & 0 deletions test/3x/torch/quantization/weight_only/test_hqq.py
@@ -1,11 +1,14 @@
import copy
import os
import time
from copy import deepcopy

import pytest
import torch
import transformers
from transformers import AutoModelForCausalLM

from neural_compressor.common import options
from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
@@ -93,6 +96,27 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
            q_label_1.eq(q_label_2)
        ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."

    def test_hqq_load_save(self, force_use_cpu, force_not_half):

        hqq_global_option.use_half = False
        fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
        example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
        # test_default_config
        quant_config = get_default_hqq_config()

        # prepare + convert API
        model = prepare(deepcopy(fp32_model), quant_config)
        qmodel = convert(model)
        qmodel_out_ref = model(example_inputs)[0]
        save_path = options.workspace + f"/_hqq_model_{time.time()}.pth"
        qmodel.save(save_path)
        from neural_compressor.torch.quantization import load

        # loading compressed model
        loaded_model = load(save_path, copy.deepcopy(fp32_model))
        loaded_model_out = loaded_model(example_inputs)[0]
        assert torch.allclose(qmodel_out_ref, loaded_model_out), "Unexpected result. Please double check."

    def test_hqq_fallback(self, force_use_cpu, force_not_half):

        class ToyModel(torch.nn.Module):
@@ -181,3 +205,57 @@ def test_hqq_module(
            scale_quant_group_size=scale_quant_group_size,
            device=torch.device(device_name),
        )

    @pytest.mark.parametrize(
        "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
        [
            (4, 64, True, False, 128),
            (4, 64, False, False, 128),
            (4, 64, True, True, 128),
            (4, 64, False, True, 128),
            (8, 64, True, False, 128),
        ],
    )
    def test_hqq_linear_save_and_load(
        self,
        nbits,
        group_size,
        quant_zero,
        quant_scale,
        scale_quant_group_size,
    ):
        hqq_global_option.use_half = False
        # Parse config
        weight_qconfig = QTensorConfig(
            nbits=nbits,
            channel_wise=True,
            group_size=group_size,
            optimize=True,
            round_zero=True if nbits == 4 else False,
        )
        zero_qconfig = None
        if quant_zero:
            zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False)
        scale_qconfig = None
        if quant_scale:
            scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False)
        hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
        # Create HQQ Linear
        bs = 4
        in_features = 64
        out_features = 128
        float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features)
        float_linear.to(device)
        float_linear_copy = deepcopy(float_linear)
        input = torch.randn(bs, in_features, device=device)
        hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config)
        out_ref = hqq_linear(input)
        state_dict = hqq_linear.state_dict()
        hqq_module_path = options.workspace + f"/_hqq_linear_{time.time()}.pth"
        torch.save(state_dict, hqq_module_path)
        reload_state_dict = torch.load(hqq_module_path)
        new_float = torch.nn.Linear(in_features=in_features, out_features=out_features)
        new_hqq_linear = HQQLinear.from_float(new_float, quant_config=hqq_quant_config)
        new_hqq_linear.load_state_dict(reload_state_dict)
        out = new_hqq_linear(input)
        assert torch.equal(out_ref, out), f"out_ref: {out_ref}, out: {out}"
