diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index 87a931610fa..e9b51a1a99a 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -34,6 +34,7 @@
     FP8Config,
     get_default_fp8_config,
     get_default_fp8_config_set,
+    get_woq_tuning_config,
 )
 from neural_compressor.torch.quantization.autotune import (
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 44c5423a34d..41b63ae0490 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -308,6 +308,7 @@ def awq_quantize_entry(
             use_full_range = op_config.use_full_range

     run_fn = kwargs.get("run_fn", None)
+    run_args = kwargs.get("run_args", None)
     example_inputs = kwargs.get("example_inputs", None)
     assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

@@ -318,6 +319,7 @@
         bits=-1,  # no quantize for op not in weight_config
         example_inputs=example_inputs,  # must be required
         run_fn=run_fn,
+        run_args=run_args,
         use_auto_scale=use_auto_scale,
         use_mse_search=use_mse_search,
         folding=folding,
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 9c1505c06a7..05f2d629a88 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -17,7 +17,9 @@
 # pylint:disable=import-error
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Tuple, Union

 import torch

@@ -57,6 +59,7 @@
     "get_default_gptq_config",
     "HQQConfig",
     "get_default_hqq_config",
+    "get_woq_tuning_config",
 ]


@@ -839,7 +842,7 @@ def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[st
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-    ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if is_ipex_imported():
             return super().to_config_mapping(config_list, model_info)
         config_mapping = OrderedDict({self.name: self})
@@ -1140,3 +1143,23 @@ def get_default_fp8_config_set() -> FP8Config:
 def get_all_registered_configs() -> Dict[str, BaseConfig]:
     registered_configs = config_registry.get_all_configs()
     return registered_configs.get(FRAMEWORK_NAME, {})
+
+
+# =============================================================================
+# Tuning Config
+# =============================================================================
+
+
+######################## WOQ Tuning Config ###############################
+def get_woq_tuning_config() -> list:
+    """Generate the config set for WOQ tuning.
+
+    Returns:
+        the list of WOQ quant configs.
+    """
+    RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
+    GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
+    GPTQ_G32ASYM_DISABLE_LAST_LINEAR = GPTQConfig(use_sym=False).set_local("*.lm_head", GPTQConfig(dtype="fp32"))
+    GPTQ_G128ASYM = GPTQConfig(group_size=128, use_sym=False)
+    AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
+    return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_LINEAR, GPTQ_G128ASYM, AWQ_G32ASYM]
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py
index 0c82a5af051..73001e9797c 100644
--- a/test/3x/torch/test_autotune.py
+++ b/test/3x/torch/test_autotune.py
@@ -308,6 +308,30 @@ def eval_acc_fn(model) -> float:
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
         self.assertIsNone(best_model)

+    def test_woq_tuning(self):
+        from neural_compressor.torch.quantization import autotune, get_woq_tuning_config
+
+        baseline = [1]
+        acc_res_lst = baseline + [0.9, 0.95, 0.95, 0.99, 1.1]
+
+        def eval_acc_fn(model):
+            res = acc_res_lst.pop(0)
+            return res
+
+        custom_tune_config = TuningConfig(config_set=get_woq_tuning_config(), tolerable_loss=-1)
+        example_inputs = torch.ones([1, 32], dtype=torch.long)
+        model = get_gpt_j()
+        dataloader = GPTQLLMDataLoader()
+        best_model = autotune(
+            model=model,
+            tune_config=custom_tune_config,
+            eval_fn=eval_acc_fn,
+            run_fn=run_fn_for_gptq,
+            run_args=(dataloader, True),  # run_args should be a tuple
+            example_inputs=example_inputs,
+        )
+        self.assertIsNone(best_model)
+

 if __name__ == "__main__":
     unittest.main()