From a8c17bfaf59385416e971c436a5f3564595f51fe Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Fri, 1 Mar 2024 16:30:42 +0800
Subject: [PATCH] [2.x API] Enable autoround export (#1641)

Signed-off-by: Kaihui-intel
---
 .azure-pipelines/scripts/ut/env_setup.sh            |  2 +-
 neural_compressor/adaptor/pytorch.py                |  4 ++
 .../adaptor/torch_utils/weight_only.py              | 10 ++++
 .../test_weight_only_adaptor_pytorch.py             | 55 +++++++++++++++++++
 .../test_weight_only_quantization.py                | 28 ++++++++++
 test/requirements.txt                               |  2 +-
 6 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh
index 1e0bf0bf8da..f3bfd6fe7db 100644
--- a/.azure-pipelines/scripts/ut/env_setup.sh
+++ b/.azure-pipelines/scripts/ut/env_setup.sh
@@ -99,7 +99,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi
 
 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+    pip install git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
 fi
 
 # test deps
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 314cd800043..efe705d299b 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -4938,6 +4938,8 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
         data_type = self.recipes["autoround_args"].get("data_type", "int")  ##only support data_type
         scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
+        # autoround export
+        export_args = self.recipes["autoround_args"].get("export_args", {"format": None})
 
         model, autoround_config = autoround_quantize(
             model=model,
@@ -4970,6 +4972,8 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
             dynamic_max_gap=dynamic_max_gap,
             data_type=data_type,
             scale_dtype=scale_dtype,
+            # export arguments
+            export_args=export_args,
         )
         return model, autoround_config
 
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index e5099490ab3..d3816323432 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -703,6 +703,7 @@ def autoround_quantize(
     dynamic_max_gap: int = -1,
     data_type: str = "int",  ##only support data_type
     scale_dtype="fp16",
+    export_args: dict = {"format": None, "inplace": True},
     **kwargs,
 ):
     """Run autoround weight-only quantization.
@@ -746,6 +747,8 @@ def autoround_quantize(
         not_use_best_mse (bool): Whether to use mean squared error (default is False).
         dynamic_max_gap (int): The dynamic maximum gap (default is -1).
         data_type (str): The data type to be used (default is "int").
+        export_args (dict): The arguments for exporting the compressed model, default is {"format": None, "inplace": True}.
+            Supported formats: "itrex", "auto_gptq".
         **kwargs: Additional keyword arguments.
 
     Returns:
@@ -787,4 +790,11 @@ def autoround_quantize(
         **kwargs,
     )
     qdq_model, weight_config = rounder.quantize()
+    if export_args["format"] is not None:
+        output_dir = export_args.get("output_dir", None)
+        format = export_args["format"]
+        inplace = export_args.get("inplace", True)
+        use_triton = export_args.get("use_triton", False)
+        model = rounder.save_quantized(output_dir=output_dir, format=format, inplace=inplace, use_triton=use_triton)
+        return model, weight_config
     return qdq_model, weight_config
diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py
index 7d75a19840d..ecfa34e56ff 100644
--- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py
+++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py
@@ -801,6 +801,61 @@ def test_AutoRound_quant(self):
         self.assertTrue("scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys())
         self.assertTrue(torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"])
 
+        fp32_model = copy.deepcopy(self.gptj)
+
+        conf = PostTrainingQuantConfig(
+            approach="weight_only",
+            op_type_dict={
+                ".*": {  # re.match
+                    "weight": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,  # -1 (per-channel)
+                        "scheme": "sym",
+                        "algorithm": "AUTOROUND",
+                    },
+                },
+            },
+            op_name_dict={
+                ".*lm_head": {  # re.match
+                    "weight": {"dtype": "fp32"},
+                },
+            },
+            recipes={
+                "autoround_args": {
+                    "n_samples": 20,
+                    "amp": False,
+                    "seq_len": 10,
+                    "iters": 10,
+                    "scale_dtype": "fp32",
+                    "device": "cpu",
+                    "export_args": {"format": "itrex", "inplace": False},
+                },
+            },
+        )
+        """All export arguments.
+
+        "export_args": {
+            "format": "itrex",  # "itrex", "auto_gptq", default is None
+            "output_dir": None,  # saved path
+            "inplace": False,
+            "use_triton": False,
+        }
+        """
+        input = torch.ones([1, 512], dtype=torch.long)
+        fp32_model = copy.deepcopy(self.gptj)
+        out1 = fp32_model(input)
+        export_model = quantization.fit(
+            fp32_model,
+            conf,
+            calib_dataloader=dataloader,
+        )
+        out2 = export_model.model(input)
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
+        from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear
+
+        self.assertTrue(isinstance(export_model.model.transformer.h[0].attn.k_proj, WeightOnlyLinear))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py
index d1e65a12a46..254742329e1 100644
--- a/test/quantization/test_weight_only_quantization.py
+++ b/test/quantization/test_weight_only_quantization.py
@@ -292,6 +292,34 @@ def test_autoround_int_quant(self):
         self.assertFalse(torch.all(out1[0] == out2[0]))
         self.assertTrue(torch.all(out2[0] == out3[0]))
 
+    def test_autoround_export(self):
+        model = copy.deepcopy(self.gptj)
+        device = "cpu"
+        model = model
+        out1 = model(self.lm_input)
+        export_model, weight_config1 = autoround_quantize(
+            model=model,
+            tokenizer=self.tokenizer,
+            n_samples=20,
+            device=device,
+            amp=False,
+            seqlen=10,
+            iters=10,
+            scale_dtype="fp32",
+            export_args={"format": "itrex", "inplace": True},
+        )
+        export_model = export_model
+        model = model
+        out2 = model(self.lm_input)
+        out3 = export_model(self.lm_input)
+        self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1)))
+        self.assertFalse(torch.all(out1[0] == out2[0]))
+        self.assertTrue(torch.all(out2[0] == out3[0]))
+
+        from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear
+
+        self.assertTrue(isinstance(export_model.transformer.h[0].attn.k_proj, WeightOnlyLinear))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/requirements.txt b/test/requirements.txt
index 3b71fd8af68..cc0205b2d72 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,7 +1,7 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 accelerate==0.21.0
 dynast==1.6.0rc1
-git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
 horovod
 intel-extension-for-pytorch
 intel-tensorflow>=2.12.0
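
Usage sketch (not part of the patch): a minimal, hedged example of driving the
new "export_args" recipe through the 2.x quantization.fit() API, mirroring the
test_AutoRound_quant case above. The tiny GPT-J checkpoint and the CalibDataset
dataloader below are illustrative placeholders, not part of this change.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForCausalLM
    from neural_compressor import PostTrainingQuantConfig, quantization

    class CalibDataset(Dataset):
        """Placeholder calibration data: 20 dummy samples of 10 token ids,
        matching the n_samples/seq_len recipe values used in the tests."""

        def __len__(self):
            return 20

        def __getitem__(self, idx):
            return torch.ones([10], dtype=torch.long)

    # Placeholder FP32 model; the unit tests use the same tiny checkpoint.
    fp32_model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-GPTJForCausalLM"
    )

    conf = PostTrainingQuantConfig(
        approach="weight_only",
        op_type_dict={
            ".*": {
                "weight": {
                    "dtype": "int",
                    "bits": 4,
                    "group_size": 32,
                    "scheme": "sym",
                    "algorithm": "AUTOROUND",
                },
            },
        },
        recipes={
            "autoround_args": {
                "n_samples": 20,
                "amp": False,
                "seq_len": 10,
                "iters": 10,
                "scale_dtype": "fp32",
                "device": "cpu",
                # Added by this patch; "format": None (the default) skips export.
                # Supported formats: "itrex", "auto_gptq".
                "export_args": {"format": "itrex", "inplace": False},
            },
        },
    )

    q_model = quantization.fit(
        fp32_model, conf, calib_dataloader=DataLoader(CalibDataset(), batch_size=1)
    )
    # With a non-None "format", autoround_quantize() returns the model produced by
    # rounder.save_quantized(), so q_model.model now holds WeightOnlyLinear modules,
    # as asserted in the tests above.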