From 2284df5d97a24eb1328595072123529f18652b6d Mon Sep 17 00:00:00 2001
From: Aleksei Kashapov
Date: Mon, 25 Nov 2024 21:27:36 +0100
Subject: [PATCH] [ONNX] Fix sporadic results in BC (#3081)

### Changes

1. This PR addresses an issue with `ONNXRuntime==1.19.2` where a tensor used as both an input and an output of a model shares the same memory. This causes unexpected behavior: updating the input tensor inadvertently modifies the collected statistics because of the memory overlap. The issue was confirmed by calling `np.shares_memory(input_data['image'], outputs['image'])`, which returned `True`, indicating that the input and output tensors share memory; after applying the proposed changes, the same check returns `False`, confirming that the memory sharing is resolved. To fix this, the `ONNXEngine` logic has been updated to create a copy of any output tensor that is also used as a model input, so that the input tensor and the statistics data remain independent and unintended side effects are avoided (see the sketch after this list).
2. Merge `RawReducer` and `NoopReducer` into a single `RawReducer`.
3. Minor fixes (remove unnecessary division-by-zero warnings in `tune_range` + fix a bug in BC).
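For illustration, a minimal standalone sketch of the aliasing check and the copying workaround from item 1. The model path `model.onnx`, the `image` input name, and the tensor shape are placeholders, and the `ONNXEngine` internals are reduced to the relevant lines:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical model whose "image" input is also routed to one of its outputs.
sess = ort.InferenceSession("model.onnx")
input_names = {inp.name for inp in sess.get_inputs()}

input_data = {"image": np.zeros((1, 3, 224, 224), dtype=np.float32)}
output_tensors = sess.run([], input_data)
outputs = {out.name: t for t, out in zip(output_tensors, sess.get_outputs())}

# With onnxruntime==1.19.2 this can print True: the output aliases the input buffer
# (https://github.com/microsoft/onnxruntime/issues/21922).
print(np.shares_memory(input_data["image"], outputs["image"]))

# Workaround applied in ONNXEngine.infer: copy every output that is also a model
# input, so later updates of the input cannot corrupt the collected statistics.
outputs_safe = {name: t.copy() if name in input_names else t for name, t in outputs.items()}
print(np.shares_memory(input_data["image"], outputs_safe["image"]))  # False
```

Whether an output actually aliases its input depends on the model and the runtime version, so the copy is applied defensively to every output name that also appears among the input names.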
### Reason for changes

Regression

### Related tickets

156025

### Tests

PTQ run 549
---
 .../common/tensor_statistics/collectors.py        | 12 ++----------
 nncf/onnx/engine.py                               | 10 ++++++++--
 nncf/onnx/graph/model_transformer.py              |  2 +-
 nncf/onnx/statistics/collectors.py                |  5 ++---
 nncf/openvino/statistics/collectors.py            |  5 ++---
 .../algorithms/bias_correction/algorithm.py       |  2 +-
 .../algorithms/weight_compression/backend.py      |  4 ++--
 nncf/quantization/fake_quantize.py                |  4 +++-
 nncf/torch/tensor_statistics/collectors.py        |  7 +++----
 .../experimental/test_reducers_and_aggregators.py |  7 +++----
 tests/common/quantization/test_tune_range.py      | 12 ++++++++++--
 .../test_weights_compression_backends.py          |  4 ++--
 12 files changed, 39 insertions(+), 35 deletions(-)
diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py
index 36b3b803566..89736bed238 100644
--- a/nncf/experimental/common/tensor_statistics/collectors.py
+++ b/nncf/experimental/common/tensor_statistics/collectors.py
@@ -412,25 +412,17 @@ def __init__(self, tensor_collectors: List[TensorCollector]) -> None:
 ##################################################


-class NoopReducer(TensorReducerBase):
+class RawReducer(TensorReducerBase):
     def __init__(self):
         super().__init__(inplace=False)

     def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
         return None

-    def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
+    def _reduce_out_of_place(self, x: List[Tensor]) -> List[Tensor]:
         return x


-class RawReducer(NoopReducer):
-    def __init__(self):
-        super().__init__()
-
-    def __call__(self, x: List[Tensor]):
-        return self._reduce_out_of_place(x)
-
-
 class ShapeReducer(TensorReducerBase):
     def __init__(self, inplace: bool = False):
         super().__init__(inplace=inplace)
diff --git a/nncf/onnx/engine.py b/nncf/onnx/engine.py
index a268e617374..d86d617d976 100644
--- a/nncf/onnx/engine.py
+++ b/nncf/onnx/engine.py
@@ -38,7 +38,13 @@ def infer(self, input_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         :param input_data: inputs for the model
         :return output_data: models outputs
         """
-        output_tensors = self.sess.run([], {k: v for k, v in input_data.items() if k in self.input_names})
+        output_tensors = self.sess.run([], input_data)
         model_outputs = self.sess.get_outputs()
-        return {output.name: tensor for tensor, output in zip(output_tensors, model_outputs)}
+        outputs_safe = {}
+        for tensor, output in zip(output_tensors, model_outputs):
+            # Workaround for https://github.com/microsoft/onnxruntime/issues/21922
+            # After fixing this copying should be removed
+            outputs_safe[output.name] = tensor.copy() if output.name in self.input_names else tensor
+
+        return outputs_safe
diff --git a/nncf/onnx/graph/model_transformer.py b/nncf/onnx/graph/model_transformer.py
index a584251dab9..7db5c0a7921 100644
--- a/nncf/onnx/graph/model_transformer.py
+++ b/nncf/onnx/graph/model_transformer.py
@@ -49,7 +49,7 @@ class ONNXModelTransformer(ModelTransformer):
     def __init__(self, model: onnx.ModelProto):
         infered_model = onnx.shape_inference.infer_shapes(model)
         super().__init__(infered_model)
-        self.onnx_model_extractor = onnx.utils.Extractor(self._model)
+        self.onnx_model_extractor = onnx.utils.Extractor(infered_model)

     def _get_target_edge(
         self,
diff --git a/nncf/onnx/statistics/collectors.py b/nncf/onnx/statistics/collectors.py
index 60a2d40a049..e590e587e12 100644
--- a/nncf/onnx/statistics/collectors.py
+++ b/nncf/onnx/statistics/collectors.py
@@ -20,7 +20,6 @@
 from nncf.experimental.common.tensor_statistics.collectors import MeanReducer
 from nncf.experimental.common.tensor_statistics.collectors import MinReducer
 from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
 from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
@@ -48,7 +47,7 @@ def get_mean_statistic_collector(
         reducer = BatchMeanReducer(inplace)
     else:
         reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
-    noop_reducer = NoopReducer()
+    raw_reducer = RawReducer()

     kwargs = {
         "num_samples": num_samples,
@@ -60,7 +59,7 @@

     collector = TensorCollector(MeanTensorStatistic)
     collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
-    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
+    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
     return collector
diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py
index a851e49f832..44a6ed606d0 100644
--- a/nncf/openvino/statistics/collectors.py
+++ b/nncf/openvino/statistics/collectors.py
@@ -24,7 +24,6 @@
 from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
 from nncf.experimental.common.tensor_statistics.collectors import MinReducer
 from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
 from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
@@ -128,7 +127,7 @@ def get_mean_statistic_collector(
         reducer = OVBatchMeanReducer(inplace)
     else:
         reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
-    noop_reducer = NoopReducer()
+    raw_reducer = RawReducer()

     kwargs = {
         "num_samples": num_samples,
@@ -139,7 +138,7 @@

     collector = TensorCollector(MeanTensorStatistic)
     collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
-    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
+    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
     return collector
diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py
index 938a1166f44..63db2ee0adf 100644
--- a/nncf/quantization/algorithms/bias_correction/algorithm.py
+++ b/nncf/quantization/algorithms/bias_correction/algorithm.py
@@ -442,7 +442,7 @@ def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Te
     """
     bias_shift_magnitude = fns.max(
         fns.abs(
-            (updated_bias_value - current_bias_value) / (current_bias_value + fns.finfo(current_bias_value).min)
+            (updated_bias_value - current_bias_value) / (current_bias_value + fns.finfo(current_bias_value).eps)
         )
     )
     return bias_shift_magnitude
diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py
index eeda65a5bd7..004bb08baef 100644
--- a/nncf/quantization/algorithms/weight_compression/backend.py
+++ b/nncf/quantization/algorithms/weight_compression/backend.py
@@ -20,7 +20,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
 from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
+from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import HessianTensorStatistic
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
@@ -252,7 +252,7 @@ def scale_insertion_command(source_node, next_nodes, source_node_output_port, sc
 class MixedPrecisionAlgoBackend(ABC):
     @staticmethod
     def hawq_statistic_collector(subset_size: Optional[int] = None) -> TensorCollector:
-        reducer = NoopReducer()
+        reducer = RawReducer()
         aggregator = HAWQAggregator(num_samples=subset_size)
         collector = TensorCollector(HessianTensorStatistic)
         collector.register_statistic_branch(HessianTensorStatistic.HESSIAN_INPUT_ACTIVATION_STATS, reducer, aggregator)
diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py
index d5a3e96ae64..12649ae26f7 100644
--- a/nncf/quantization/fake_quantize.py
+++ b/nncf/quantization/fake_quantize.py
@@ -122,11 +122,13 @@ def tune_range(
     fval = -left_border * s
     qval = fns.round(fval)

-    ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border)
     with warnings.catch_warnings():
         # If `qval` is 0 `rb` will equal `right_border`, and we don't want to show an unnecessary division by 0 warning
+        # The same for (qval - level_high)
         warnings.simplefilter("ignore")
+        ra_then_result = qval / (qval - level_high) * right_border
         rb_then_result = (qval - level_high) / qval * left_border
+        ra = fns.where(qval < level_high, ra_then_result, left_border)
         rb = fns.where(qval > 0.0, rb_then_result, right_border)

     range_a = right_border - ra
diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py
index abe7aa60acd..8706fc076b2 100644
--- a/nncf/torch/tensor_statistics/collectors.py
+++ b/nncf/torch/tensor_statistics/collectors.py
@@ -27,7 +27,6 @@
 from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MinReducer
 from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
 from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
 from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
@@ -246,7 +245,7 @@ def _get_collection_without_reduction(
     :return: Target statistic collector.
     """
     tensor_collector = TensorCollector(statistic_cls)
-    reducer = NoopReducer()
+    reducer = RawReducer()
     aggregation_axes = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes]))
     aggregator = aggregator_cls(
         aggregation_axes=aggregation_axes,
@@ -311,7 +310,7 @@ def get_mean_statistic_collector(
         reducer = BatchMeanReducer()
     else:
         reducer = MeanPerChReducer(channel_axis=channel_axis)
-    noop_reducer = NoopReducer()
+    raw_reducer = RawReducer()

     kwargs = {
         "num_samples": num_samples,
@@ -322,7 +321,7 @@

     collector = TensorCollector(MeanTensorStatistic)
     collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
-    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
+    collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
     return collector
diff --git a/tests/common/experimental/test_reducers_and_aggregators.py b/tests/common/experimental/test_reducers_and_aggregators.py
index 5544a426a54..9b490e5a0d2 100644
--- a/tests/common/experimental/test_reducers_and_aggregators.py
+++ b/tests/common/experimental/test_reducers_and_aggregators.py
@@ -31,7 +31,6 @@
 from nncf.experimental.common.tensor_statistics.collectors import MedianNoOutliersAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
 from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
 from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
@@ -173,19 +172,19 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non
     def cast_tensor(self, tensor, dtype: Dtype):
         pass

-    @pytest.mark.parametrize("reducer_cls", [NoopReducer, RawReducer])
+    @pytest.mark.parametrize("reducer_cls", [RawReducer])
     @pytest.mark.parametrize("input_data", [np.arange(24).reshape((1, 2, 3, 4)), np.array([])])
     def test_other_reducers(self, reducer_cls, input_data):
         reducer = reducer_cls()
         tensor_data = self.get_nncf_tensor(input_data)
         reduced_input = reducer([tensor_data])
-        if reducer_cls == NoopReducer and tensor_data.isempty():
+        if tensor_data.isempty():
             assert reduced_input is None
         else:
             assert len(reduced_input) == 1
             assert fns.allclose(reduced_input[0], tensor_data)

-    @pytest.mark.parametrize("reducer_cls", [NoopReducer, RawReducer, ShapeReducer])
+    @pytest.mark.parametrize("reducer_cls", [RawReducer, ShapeReducer])
     def test_other_reducers_name_hash_equal(self, reducer_cls):
         reducers_instances = [reducer_cls() for _ in range(2)]
         assert hash(reducers_instances[0]) == hash(reducers_instances[1])
diff --git a/tests/common/quantization/test_tune_range.py b/tests/common/quantization/test_tune_range.py
index 92fe5351298..6935f130992 100644
--- a/tests/common/quantization/test_tune_range.py
+++ b/tests/common/quantization/test_tune_range.py
@@ -11,13 +11,21 @@
 import warnings

 import numpy as np
+import pytest

 from nncf.quantization.fake_quantize import tune_range
 from nncf.tensor import Tensor


-def test_tune_range_zero_division_warning():
+@pytest.mark.parametrize(
+    "params",
+    (
+        (Tensor(np.array([0.0])), Tensor(np.array([1.0])), 8, False),
+        (Tensor(np.array([-1.0])), Tensor(np.array([0.0])), 8, False),
+    ),
+)
+def test_tune_range_zero_division_warning(params):
     with warnings.catch_warnings(record=True) as w:
         # Calling tune_range should not raise a warning
-        tune_range(Tensor(np.array([0.0])), Tensor(np.array([1.0])), 8, False)
+        tune_range(*params)
     assert len(w) == 0
diff --git a/tests/cross_fw/test_templates/test_weights_compression_backends.py b/tests/cross_fw/test_templates/test_weights_compression_backends.py
index fb521fd0726..4d72e62239a 100644
--- a/tests/cross_fw/test_templates/test_weights_compression_backends.py
+++ b/tests/cross_fw/test_templates/test_weights_compression_backends.py
@@ -18,7 +18,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MeanVarianceReducer
-from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
+from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector


@@ -58,7 +58,7 @@ def check_reducer(self, collector: TensorCollector, expected_reducer_type):
     @pytest.mark.parametrize(
         "algo_func, aggregator_type, reducer_type",
         [
-            ("get_hawq_with_backend", HAWQAggregator, NoopReducer),
+            ("get_hawq_with_backend", HAWQAggregator, RawReducer),
             ("get_mean_variance_with_backend", MeanAggregator, MeanVarianceReducer),
             ("get_max_variance_with_backend", MeanAggregator, MaxVarianceReducer),
             ("get_mean_max_with_backend", MeanAggregator, MeanAbsMaxReducer),