From 53ecfd878c204ddb73d3c72b887bea8661c6e5ae Mon Sep 17 00:00:00 2001 From: zina-cs <109593976+zina-cs@users.noreply.github.com> Date: Thu, 26 Sep 2024 18:48:20 +0400 Subject: [PATCH] I've built on awayzjj's test_dump method and added checks if the dataset has len implemented or not. @l-bat , awaiting your feedback, thank you. --- .../test_quantization_pipeline.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/openvino/native/quantization/test_quantization_pipeline.py b/tests/openvino/native/quantization/test_quantization_pipeline.py index 8196dbdfc75..302808a0339 100644 --- a/tests/openvino/native/quantization/test_quantization_pipeline.py +++ b/tests/openvino/native/quantization/test_quantization_pipeline.py @@ -197,3 +197,34 @@ def test_ignored_scope_dump(ignored_options, expected_dump, tmp_path): assert dumped_model.get_rt_info(rt_path) == value else: assert dumped_model.has_rt_info(rt_path) is False + +@pytest.mark.parametrize("subset_size, expected_actual_subset_size", [[1, 1], [2, 1]]) +def test_dump(subset_size, expected_actual_subset_size, tmp_path): + model = WeightsModel().ov_model + dataset = get_dataset_for_test(model) # dataset.get_length() == 1 + quantize_parameters = { + "preset": QuantizationPreset.PERFORMANCE, + "target_device": TargetDevice.CPU, + "subset_size": subset_size, + "fast_bias_correction": True, + } + + quantized_model = quantize_impl(model, dataset, **quantize_parameters) + ov.save_model(quantized_model, tmp_path / "ov_model.xml") + core = ov.Core() + dumped_model = core.read_model(tmp_path / "ov_model.xml") + + assert dumped_model.get_rt_info(["nncf", "quantization", "actual_subset_size"]) == str(expected_actual_subset_size) + + # Check if dataset has __len__ implemented + if hasattr(dataset, '__len__'): + with warnings.catch_warnings(record=True) as w: + collect_statistics(dataset, subset_size) + if len(dataset) < subset_size: + assert len(w) > 0 + assert "smaller than subset_size" in str(w[-1].message) + else: + # Handle the case when __len__ is not implemented + with warnings.catch_warnings(record=True) as w: + collect_statistics(dataset, subset_size) + assert len(w) == 0 # No warning expected