From e6664b05a4da8421b2fc2748d0c40eb1a2cad0c4 Mon Sep 17 00:00:00 2001
From: xinhe
Date: Wed, 28 Feb 2024 15:29:32 +0800
Subject: [PATCH] add fp8 autotune ut and fix bug in autotune (#1638)

Signed-off-by: xin3he
---
 .../torch/quantization/__init__.py            |  1 +
 .../torch/quantization/autotune.py            |  7 +++-
 .../torch/quantization/config.py              | 13 ++++++-
 .../torch/quantization/habana_fp8/test_fp8.py | 38 ++++++++++++++++++-
 4 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index 7892df51e23..e29b0017ee7 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -31,6 +31,7 @@
     get_default_hqq_config,
     FP8Config,
     get_default_fp8_config,
+    get_default_fp8_config_set,
 )
 
 from neural_compressor.torch.quantization.autotune import (
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index 3a900985967..48c3bfe18d9 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -72,10 +72,13 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
+            del q_model  # maybe gc.collect() is needed for memory release
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
-            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
-            best_quant_model = model  # quantize model inplace
+            q_model = quantize(
+                deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True
+            )
+            best_quant_model = q_model  # quantize model inplace
             break
     tuning_logger.tuning_end()
     return best_quant_model
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index be5b221132c..567a24a6d16 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -941,14 +941,23 @@ def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]
 
 
 def get_default_fp8_config() -> FP8Config:
-    """Generate the default gptq config.
+    """Generate the default fp8 config.
 
     Returns:
-        the default gptq config.
+        the default fp8 config.
     """
     return FP8Config()
 
 
+def get_default_fp8_config_set() -> FP8Config:
+    """Generate the default fp8 config set.
+
+    Returns:
+        the default fp8 config set.
+    """
+    return FP8Config.get_config_set_for_tuning()
+
+
 ##################### Algo Configs End ###################################
 
 
diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py b/test/3x/torch/quantization/habana_fp8/test_fp8.py
index 41e3af35870..d45cf80ef1d 100644
--- a/test/3x/torch/quantization/habana_fp8/test_fp8.py
+++ b/test/3x/torch/quantization/habana_fp8/test_fp8.py
@@ -18,8 +18,14 @@
         FP8Matmul,
         Matmul,
     )
-    from neural_compressor.torch.quantization import quantize
-    from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
+    from neural_compressor.torch.quantization import (
+        FP8Config,
+        TuningConfig,
+        autotune,
+        get_default_fp8_config,
+        get_default_fp8_config_set,
+        quantize,
+    )
 
     torch.set_grad_enabled(False)
 
@@ -164,3 +170,31 @@ def calib_func(model):
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."
         assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check."
         assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
+
+    def test_autotune(self):
+        m = copy.deepcopy(self.model)
+        inp = self.inp
+        fp32_out = m(inp)
+
+        def calib_func(model):
+            model(inp)
+
+        accu_list = [1.0, 0.9, 0.99]
+
+        def eval_func(model):
+            nonlocal accu_list
+            return accu_list.pop()
+
+        tune_config = TuningConfig(
+            config_set=get_default_fp8_config_set(),
+            tolerable_loss=0.01,
+        )
+        best_model = autotune(
+            model=m,
+            tune_config=tune_config,
+            run_fn=calib_func,
+            eval_fns=eval_func,
+        )
+        assert isinstance(best_model.fc1, FP8Linear), "Unexpected result. Please double check."
+        assert isinstance(best_model.mm, FP8Matmul), "Unexpected result. Please double check."
+        assert isinstance(best_model.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
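
Usage note (illustrative, not part of the patch): the sketch below shows how the autotune flow
exercised by the new test_autotune case would be driven from user code. The toy model,
example_input, calib_fn, and eval_fn are hypothetical stand-ins; only TuningConfig, autotune,
and get_default_fp8_config_set come from the exports added or used by this patch, and the
keyword arguments mirror the ones used in the test.

    import torch
    from neural_compressor.torch.quantization import (
        TuningConfig,
        autotune,
        get_default_fp8_config_set,
    )

    # Hypothetical fp32 model and calibration data; any torch.nn.Module is handled the same way.
    model = torch.nn.Linear(10, 10)
    example_input = torch.randn(1, 10)

    def calib_fn(model):
        # Run calibration data through the model (same role as calib_func in the new test).
        model(example_input)

    def eval_fn(model):
        # Return an accuracy-like score for the candidate quantized model; higher is better.
        return 1.0

    tune_config = TuningConfig(
        config_set=get_default_fp8_config_set(),  # default FP8 config set added by this patch
        tolerable_loss=0.01,                      # same tolerance the new unit test uses
    )
    best_model = autotune(
        model=model,
        tune_config=tune_config,
        run_fn=calib_fn,
        eval_fns=eval_fn,
    )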