
[2.x API] Enable autoround export (#1641)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Mar 1, 2024
1 parent 8e833bd commit a8c17bf
Showing 6 changed files with 99 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
@@ -99,7 +99,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
fi

if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+    pip install git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
fi

# test deps
4 changes: 4 additions & 0 deletions neural_compressor/adaptor/pytorch.py
@@ -4938,6 +4938,8 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
data_type = self.recipes["autoround_args"].get("data_type", "int")  ## only the "int" data type is supported
scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
# autoround export
export_args = self.recipes["autoround_args"].get("export_args", {"format": None})

model, autoround_config = autoround_quantize(
model=model,
@@ -4970,6 +4972,8 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
dynamic_max_gap=dynamic_max_gap,
data_type=data_type,
scale_dtype=scale_dtype,
# export arguments
export_args=export_args,
)
return model, autoround_config
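For context, a minimal sketch of how a caller can trigger this export path through the 2.x recipes API; user_fp32_model and calib_dataloader are placeholders, and the full configuration is exercised in the test file below:

from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "AUTOROUND"}}},
    recipes={
        "autoround_args": {
            "iters": 10,
            "export_args": {"format": "itrex", "inplace": False},  # request export after quantization
        }
    },
)
q_model = quantization.fit(user_fp32_model, conf, calib_dataloader=calib_dataloader)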

10 changes: 10 additions & 0 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -703,6 +703,7 @@ def autoround_quantize(
dynamic_max_gap: int = -1,
data_type: str = "int",  ## only the "int" data type is supported
scale_dtype="fp16",
export_args: dict = {"format": None, "inplace": True},
**kwargs,
):
"""Run autoround weight-only quantization.
@@ -746,6 +747,8 @@
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
export_args (dict): Arguments for exporting the compressed model, default is {"format": None, "inplace": True}.
    Supported formats: "itrex", "auto_gptq".
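    Other recognized keys: "output_dir" (directory to save the exported model, default None) and
    "use_triton" (bool, default False); all are forwarded to rounder.save_quantized below.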
**kwargs: Additional keyword arguments.
Returns:
@@ -787,4 +790,11 @@
**kwargs,
)
qdq_model, weight_config = rounder.quantize()
if export_args.get("format") is not None:
    # Export the quantized model to the requested compressed format.
    output_dir = export_args.get("output_dir", None)  # None keeps the exported model in memory only
    export_format = export_args["format"]  # "itrex" or "auto_gptq"
    inplace = export_args.get("inplace", True)
    use_triton = export_args.get("use_triton", False)
    model = rounder.save_quantized(output_dir=output_dir, format=export_format, inplace=inplace, use_triton=use_triton)
    return model, weight_config
return qdq_model, weight_config
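A usage sketch of the function-level export path added above; fp32_model and tokenizer are placeholders for any causal LM and its tokenizer, and the tests below cover the same flow:

from neural_compressor.adaptor.torch_utils.weight_only import autoround_quantize

packed_model, weight_config = autoround_quantize(
    model=fp32_model,
    tokenizer=tokenizer,
    iters=10,
    export_args={"format": "itrex", "inplace": True},  # return the packed model instead of the QDQ model
)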
55 changes: 55 additions & 0 deletions test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py
@@ -801,6 +801,61 @@ def test_AutoRound_quant(self):
self.assertTrue("scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys())
self.assertTrue(torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"])

fp32_model = copy.deepcopy(self.gptj)

conf = PostTrainingQuantConfig(
approach="weight_only",
op_type_dict={
".*": { # re.match
"weight": {
"dtype": "int",
"bits": 4,
"group_size": 32, # -1 (per-channel)
"scheme": "sym",
"algorithm": "AUTOROUND",
},
},
},
op_name_dict={
".*lm_head": { # re.match
"weight": {"dtype": "fp32"},
},
},
recipes={
"autoround_args": {
"n_samples": 20,
"amp": False,
"seq_len": 10,
"iters": 10,
"scale_dtype": "fp32",
"device": "cpu",
"export_args": {"format": "itrex", "inplace": False},
},
},
)
"""All export arguments.
"export_args": {
"format": "itrex", # "iterx", "auto_gptq", default is None
"output_dir": None, # saved path
"inplace": False,
"use_triton": False,
}
"""
input = torch.ones([1, 512], dtype=torch.long)
out1 = fp32_model(input)
export_model = quantization.fit(
fp32_model,
conf,
calib_dataloader=dataloader,
)
out2 = export_model.model(input)
self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear

self.assertTrue(isinstance(export_model.model.transformer.h[0].attn.k_proj, WeightOnlyLinear))


if __name__ == "__main__":
unittest.main()
28 changes: 28 additions & 0 deletions test/quantization/test_weight_only_quantization.py
@@ -292,6 +292,34 @@ def test_autoround_int_quant(self):
self.assertFalse(torch.all(out1[0] == out2[0]))
self.assertTrue(torch.all(out2[0] == out3[0]))

def test_autoround_export(self):
model = copy.deepcopy(self.gptj)
device = "cpu"
out1 = model(self.lm_input)
export_model, weight_config1 = autoround_quantize(
model=model,
tokenizer=self.tokenizer,
n_samples=20,
device=device,
amp=False,
seqlen=10,
iters=10,
scale_dtype="fp32",
export_args={"format": "itrex", "inplace": True},
)
out2 = model(self.lm_input)
out3 = export_model(self.lm_input)
self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1)))
self.assertFalse(torch.all(out1[0] == out2[0]))
self.assertTrue(torch.all(out2[0] == out3[0]))

from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear

self.assertTrue(isinstance(export_model.transformer.h[0].attn.k_proj, WeightOnlyLinear))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/requirements.txt
@@ -1,7 +1,7 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
accelerate==0.21.0
dynast==1.6.0rc1
-git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
horovod
intel-extension-for-pytorch
intel-tensorflow>=2.12.0
