Skip to content

Commit

Permalink
Add check on None (#2558)
Browse files Browse the repository at this point in the history
### Changes

Add a check on None for OverflowFix parameter.

### Reason for changes

Fix bug

### Related tickets

N/A

### Tests

Add test on a failed before scenario
  • Loading branch information
kshpv authored Mar 8, 2024
1 parent 09960b9 commit edd397b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def apply_advanced_parameters_to_config(
:param params: Advanced quantization parameters
:return: advanced quantization parameters as dict in the legacy format
"""
config["overflow_fix"] = params.overflow_fix.value
config["overflow_fix"] = params.overflow_fix if params.overflow_fix is None else params.overflow_fix.value
config["quantize_outputs"] = params.quantize_outputs

if params.disable_bias_correction:
Expand Down
9 changes: 9 additions & 0 deletions tests/tensorflow/quantization/test_ptq_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import OverflowFix
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.advanced_parameters import apply_advanced_parameters_to_config
from nncf.quantization.range_estimator import RangeEstimatorParametersSet
from nncf.scopes import IgnoredScope
from nncf.tensorflow.quantization.quantize_model import _create_nncf_config
from nncf.tensorflow.quantization.quantize_model import _get_default_quantization_config


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,3 +102,10 @@ def test_create_nncf_config(params):
# To validate NNCFConfig requared input_info
config["input_info"] = {"sample_size": [1, 2, 224, 224]}
NNCFConfig.validate(config)


@pytest.mark.parametrize("preset", (QuantizationPreset.MIXED, QuantizationPreset.PERFORMANCE))
@pytest.mark.parametrize("advanced_quantization_params", (AdvancedQuantizationParameters(),))
def test_apply_advanced_parameters_to_config(preset, advanced_quantization_params):
compression_config = _get_default_quantization_config(preset, 1)
assert apply_advanced_parameters_to_config(compression_config, advanced_quantization_params)

0 comments on commit edd397b

Please sign in to comment.