From edd397b01f8c45b976785fa67d425467d80835ae Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Fri, 8 Mar 2024 12:20:50 +0100 Subject: [PATCH] Add check on None (#2558) ### Changes Add a check on None for OverflowFix parameter. ### Reason for changes Fix bug ### Related tickets N/A ### Tests Add test on a failed before scenario --- nncf/quantization/advanced_parameters.py | 2 +- tests/tensorflow/quantization/test_ptq_params.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py index b28edce49e0..dff463260be 100644 --- a/nncf/quantization/advanced_parameters.py +++ b/nncf/quantization/advanced_parameters.py @@ -396,7 +396,7 @@ def apply_advanced_parameters_to_config( :param params: Advanced quantization parameters :return: advanced quantization parameters as dict in the legacy format """ - config["overflow_fix"] = params.overflow_fix.value + config["overflow_fix"] = params.overflow_fix if params.overflow_fix is None else params.overflow_fix.value config["quantize_outputs"] = params.quantize_outputs if params.disable_bias_correction: diff --git a/tests/tensorflow/quantization/test_ptq_params.py b/tests/tensorflow/quantization/test_ptq_params.py index 9b859e886aa..6dc2f90e708 100644 --- a/tests/tensorflow/quantization/test_ptq_params.py +++ b/tests/tensorflow/quantization/test_ptq_params.py @@ -19,9 +19,11 @@ from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from nncf.quantization.advanced_parameters import QuantizationParameters +from nncf.quantization.advanced_parameters import apply_advanced_parameters_to_config from nncf.quantization.range_estimator import RangeEstimatorParametersSet from nncf.scopes import IgnoredScope from nncf.tensorflow.quantization.quantize_model import _create_nncf_config +from nncf.tensorflow.quantization.quantize_model import _get_default_quantization_config @pytest.mark.parametrize( @@ -100,3 +102,10 @@ def test_create_nncf_config(params): # To validate NNCFConfig requared input_info config["input_info"] = {"sample_size": [1, 2, 224, 224]} NNCFConfig.validate(config) + + +@pytest.mark.parametrize("preset", (QuantizationPreset.MIXED, QuantizationPreset.PERFORMANCE)) +@pytest.mark.parametrize("advanced_quantization_params", (AdvancedQuantizationParameters(),)) +def test_apply_advanced_parameters_to_config(preset, advanced_quantization_params): + compression_config = _get_default_quantization_config(preset, 1) + assert apply_advanced_parameters_to_config(compression_config, advanced_quantization_params)