diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 1f74aa0a7df..c1797aadccb 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -125,6 +125,7 @@ def __init__( tensor_processor: NNCFCollectorTensorProcessor, aggregation_axes: Optional[AggregationAxes] = None, num_samples: Optional[int] = None, + window_size: Optional[int] = None, ): """ :param tensor_processor: Backend-specific tensor processor. @@ -134,14 +135,17 @@ def __init__( :param num_samples: Maximum number of samples to collect. Aggregator skips tensor registration if tensor registration was called num_samples times before. Aggregator never skips registration if num_samples is None. + :param window_size: Number of samples from the end of the list of collected samples to aggregate. + Aggregates all available collected statistics in case parameter is None. """ self._tensor_processor = tensor_processor self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes - self._keepdims = False + self._keepdims = True self._num_samples = num_samples self._collected_samples = 0 - self._container = [] + self._window_size = window_size + self._container = deque(maxlen=window_size) @property def num_samples(self) -> int: @@ -594,20 +598,7 @@ def _aggregate_impl(self): return self._container.shape -class TensorAggregatorBase(AggregatorBase, ABC): - def __init__( - self, - tensor_processor: NNCFCollectorTensorProcessor, - aggregation_axes: Optional[AggregationAxes] = None, - num_samples: Optional[int] = None, - window_size=None, - ): - super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples) - self._window_size = window_size - self._container = deque(maxlen=window_size) - - -class OnlineAggregatorBase(TensorAggregatorBase, ABC): +class OnlineAggregatorBase(AggregatorBase, ABC): """ Base class for aggregators which are using aggregation function fn with following property: fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]) where x1, x2, x3 are samples to aggregate. @@ -622,19 +613,17 @@ def _register_reduced_input_impl(self, x: NNCFTensor) -> None: else: reduced = x if 0 in self._aggregation_axes: - if self._container: - reduced = self._aggregation_fn( - self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False - ) - self._container = reduced + stacked_tensors = self._tensor_processor.stack([reduced, *self._container], axis=0) + aggregated = self._aggregation_fn(stacked_tensors, axis=0, keepdims=self._keepdims) + aggregated = self._tensor_processor.squeeze(aggregated, 0) + self._container = [aggregated] else: self._container.append(reduced) def _aggregate_impl(self) -> NNCFTensor: if 0 in self._aggregation_axes: if self._keepdims: - return self._tensor_processor.stack([self._container]).tensor - return self._container.tensor + return self._container[0].tensor return self._tensor_processor.stack(self._container).tensor @abstractmethod @@ -652,7 +641,7 @@ def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keep return self._tensor_processor.reduce_max(stacked_value, axis=axis, keepdims=keepdims) -class OfflineAggregatorBase(TensorAggregatorBase, ABC): +class OfflineAggregatorBase(AggregatorBase, ABC): """ Base class for aggregators which are using aggregation function fn which does not fulfill property fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]) @@ -665,7 +654,8 @@ def _register_reduced_input_impl(self, x: TensorType) -> None: def _aggregate_impl(self) -> NNCFTensor: stacked_val = self._tensor_processor.stack(self._container) - return self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims).tensor + aggregated = self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims) + return self._tensor_processor.squeeze(aggregated, 0).tensor @abstractmethod def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keepdims: bool) -> NNCFTensor: @@ -699,13 +689,19 @@ def __init__( def _aggregate_impl(self) -> NNCFTensor: stacked_samples = self._tensor_processor.stack(self._container) low_values, high_values = self._tensor_processor.quantile( - stacked_samples, quantile=(self._quantile, 1 - self._quantile), axis=self._aggregation_axes + stacked_samples, + quantile=(self._quantile, 1 - self._quantile), + axis=self._aggregation_axes, ) tp = self._tensor_processor outliers_mask = tp.logical_or(tp.less(stacked_samples, low_values), tp.less(high_values, stacked_samples)) - return self._aggregation_fn( - stacked_samples=stacked_samples, mask=outliers_mask, axis=self._aggregation_axes, keepdims=self._keepdims - ).tensor + aggregated = self._aggregation_fn( + stacked_samples=stacked_samples, + mask=outliers_mask, + axis=self._aggregation_axes, + keepdims=self._keepdims, + ) + return self._tensor_processor.squeeze(aggregated, 0).tensor @abstractmethod def _aggregation_fn( @@ -734,7 +730,7 @@ def _aggregation_fn( return self._tensor_processor.masked_median(stacked_samples, axis=axis, mask=mask, keepdims=keepdims) -class MedianAbsoluteDeviationAggregator(TensorAggregatorBase): +class MedianAbsoluteDeviationAggregator(AggregatorBase): def __init__( self, tensor_processor: NNCFCollectorTensorProcessor, @@ -780,7 +776,7 @@ def _aggregate_impl(self) -> Dict[str, NNCFTensor]: } -class PercentileAggregator(TensorAggregatorBase): +class PercentileAggregator(AggregatorBase): def __init__( self, tensor_processor: NNCFCollectorTensorProcessor, @@ -850,9 +846,7 @@ def _moveaxes_flatten_cat( transposed_t = tensor_processor.transpose(tensor, transpose_dims) reshaped_tensors.append(tensor_processor.reshape(transposed_t, reshape_shape)) - shape_after_aggregation = (1,) + tuple( - 1 if idx in aggregation_axes else dim for idx, dim in enumerate(tensor_shape) - ) + shape_after_aggregation = tuple(1 if idx in aggregation_axes else dim for idx, dim in enumerate(tensor_shape)) return tensor_processor.cat(reshaped_tensors, axis=0), shape_after_aggregation diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index ea02b952959..34bc7693218 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -20,6 +20,7 @@ from nncf.common.tensor_statistics.collectors import NNCFTensor from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer +from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase from nncf.experimental.common.tensor_statistics.collectors import BatchMeanReducer from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator from nncf.experimental.common.tensor_statistics.collectors import MaxReducer @@ -33,7 +34,6 @@ from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator -from nncf.experimental.common.tensor_statistics.collectors import TensorAggregatorBase from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.advanced_parameters import StatisticsType from nncf.torch.tensor import PTNNCFTensor @@ -443,8 +443,8 @@ def get_percentile_tensor_collector( def _get_collection_without_reduction( - aggregator_cls: TensorAggregatorBase, - statistic_cls: TensorAggregatorBase, + aggregator_cls: AggregatorBase, + statistic_cls: AggregatorBase, reduction_axes: Tuple[int, ...], aggregation_axes: Tuple[int, ...], num_samples: int, diff --git a/tests/common/experimental/test_reducers_and_aggregators.py b/tests/common/experimental/test_reducers_and_aggregators.py index 9a0e5a8d4a9..b6e76c1e029 100644 --- a/tests/common/experimental/test_reducers_and_aggregators.py +++ b/tests/common/experimental/test_reducers_and_aggregators.py @@ -59,49 +59,47 @@ class OfflineAggregatorTestCase: OFFLINE_AGGREGATORS_TEST_CASES = [ - OfflineAggregatorTestCase( - aggregation_axes=None, - min_ref=np.array([[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]), - max_ref=np.array([[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]), - ), OfflineAggregatorTestCase( aggregation_axes=(0,), min_ref=np.array([[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]), max_ref=np.array([[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]), ), OfflineAggregatorTestCase( - aggregation_axes=(0, 2), - min_ref=np.array([[-50000, -28, -32]]), - max_ref=np.array([[50000, 28, 32]]), + aggregation_axes=( + 0, + 2, + ), + min_ref=np.array([[[-50000, -28, -32]]]), + max_ref=np.array([[[50000, 28, 32]]]), ), OfflineAggregatorTestCase( aggregation_axes=(2,), min_ref=np.array( [ - [[-50000, 5, 10]], - [[-40000, 4, 8]], - [[-30000, 3, 6]], - [[-20000, 2, 4]], - [[-10000, 1, 2]], - [[0, 0, 0]], - [[-6, -7, -8]], - [[-12, -14, -16]], - [[-18, -21, -24]], - [[-24, -28, -32]], + [[[-50000, 5, 10]]], + [[[-40000, 4, 8]]], + [[[-30000, 3, 6]]], + [[[-20000, 2, 4]]], + [[[-10000, 1, 2]]], + [[[0, 0, 0]]], + [[[-6, -7, -8]]], + [[[-12, -14, -16]]], + [[[-18, -21, -24]]], + [[[-24, -28, -32]]], ] ), max_ref=np.array( [ - [[50000, -5, -10]], - [[40000, -4, -8]], - [[30000, -3, -6]], - [[20000, -2, -4]], - [[10000, -1, -2]], - [[0, 0, 0]], - [[6, 7, 8]], - [[12, 14, 16]], - [[18, 21, 24]], - [[24, 28, 32]], + [[[50000, -5, -10]]], + [[[40000, -4, -8]]], + [[[30000, -3, -6]]], + [[[20000, -2, -4]]], + [[[10000, -1, -2]]], + [[[0, 0, 0]]], + [[[6, 7, 8]]], + [[[12, 14, 16]]], + [[[18, 21, 24]]], + [[[24, 28, 32]]], ] ), ), @@ -254,8 +252,8 @@ def test_min_max_aggregators( assert self.all_close(max_aggregator.aggregate(), max_ref) NO_OUTLIERS_TEST_PARAMS = [ - (MeanAggregator, True, 1, 1404.5138888888905), - (MedianAggregator, True, 1, 24.0), + (MeanAggregator, True, 1, [1404.5138888888905]), + (MedianAggregator, True, 1, [24.0]), ( MeanAggregator, False, @@ -263,16 +261,16 @@ def test_min_max_aggregators( [2503.125, -2493.75, 5009.375, -4987.5, 7515.625, -7481.25, 10021.875, -9975.0, 12528.125], ), (MedianAggregator, False, 1, [4.5, 5.0, 13.5, 10.0, 22.5, 15.0, 31.5, 20.0, 40.5]), - (MeanAggregator, True, 2, [2512.5, -1651.04166667, 3352.08333333]), - (MedianAggregator, True, 2, [13.0, 12.5, 21.0]), + (MeanAggregator, True, 2, [[2512.5, -1651.04166667, 3352.08333333]]), + (MedianAggregator, True, 2, [[13.0, 12.5, 21.0]]), (MeanAggregator, False, 2, DEFALUT_3D_MEAN_VALUE), (MedianAggregator, False, 2, DEFALUT_3D_MEDIAN_VALUE), - (MeanAggregator, True, 3, DEFALUT_3D_MEAN_VALUE), - (MedianAggregator, True, 3, DEFALUT_3D_MEDIAN_VALUE), + (MeanAggregator, True, 3, [DEFALUT_3D_MEAN_VALUE]), + (MedianAggregator, True, 3, [DEFALUT_3D_MEDIAN_VALUE]), (MeanAggregator, False, 3, [DEFALUT_3D_MEAN_VALUE]), (MedianAggregator, False, 3, [DEFALUT_3D_MEDIAN_VALUE]), - (default_test_mean_no_outlier, True, 1, 20.0893), - (default_test_median_no_outlier, True, 1, 30.0), + (default_test_mean_no_outlier, True, 1, [20.0893]), + (default_test_median_no_outlier, True, 1, [30.0]), ( default_test_mean_no_outlier, False, @@ -280,12 +278,12 @@ def test_min_max_aggregators( [4.16666667, 8.33333333, 12.5, 16.66666667, 20.83333333, 25.0, 29.16666667, 33.33333333, 37.5], ), (default_test_median_no_outlier, False, 1, [5.0, 4.0, 15.0, 8.0, 25.0, 12.0, 35.0, 16.0, 45.0]), - (default_test_mean_no_outlier, True, 2, [16.66666667, 20.83333333, 25.0]), - (default_test_median_no_outlier, True, 2, [14.0, 10.0, 24.0]), + (default_test_mean_no_outlier, True, 2, [[16.66666667, 20.83333333, 25.0]]), + (default_test_median_no_outlier, True, 2, [[14.0, 10.0, 24.0]]), (default_test_mean_no_outlier, False, 2, NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE), (default_test_median_no_outlier, False, 2, NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE), - (default_test_mean_no_outlier, True, 3, NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE), - (default_test_median_no_outlier, True, 3, NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE), + (default_test_mean_no_outlier, True, 3, [NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE]), + (default_test_median_no_outlier, True, 3, [NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE]), (default_test_mean_no_outlier, False, 3, [NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE]), (default_test_median_no_outlier, False, 3, [NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE]), ] @@ -348,8 +346,8 @@ def aggregator_cls_fixture(self, request): "mad_values": np.array([2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5]), }, (0, 1): { - "median_values": np.array(18.0), - "mad_values": np.array(12.0), + "median_values": np.array([18.0]), + "mad_values": np.array([12.0]), }, }, PercentileAggregator: { @@ -366,10 +364,10 @@ def aggregator_cls_fixture(self, request): 95: np.array([7.6, 15.2, 22.8, 30.4, 38.0, 45.6, 53.2, 60.8, 68.4]), }, (0, 1): { - 5: np.array(0.0), - 10: np.array(0.0), - 90: np.array(48.0), - 95: np.array(56.0), + 5: np.array([0.0]), + 10: np.array([0.0]), + 90: np.array([48.0]), + 95: np.array([56.0]), }, }, } @@ -389,14 +387,14 @@ def test_mad_percentile_aggregators(self, MAD_precentile_aggregator_cls, tensor_ REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS = { MedianAbsoluteDeviationAggregator: { - "median_values": np.array([28.5, 35.5, 43.5]), - "mad_values": np.array([24.0, 24.0, 24.0]), + "median_values": np.array([[28.5, 35.5, 43.5]]).reshape(1, 3, 1), + "mad_values": np.array([[[24.0, 24.0, 24.0]]]).reshape(1, 3, 1), }, PercentileAggregator: { - 5: np.array([0.95, 5.95, 9.95]), - 10: np.array([1.9, 7.9, 15.5]), - 90: np.array([75.1, 83.1, 91.1]), - 95: np.array([77.05, 85.05, 93.05]), + 5: np.array([[[0.95, 5.95, 9.95]]]).reshape(1, 3, 1), + 10: np.array([[[1.9, 7.9, 15.5]]]).reshape(1, 3, 1), + 90: np.array([[[75.1, 83.1, 91.1]]]).reshape(1, 3, 1), + 95: np.array([[[77.05, 85.05, 93.05]]]).reshape(1, 3, 1), }, } diff --git a/tests/post_training/test_templates/test_channel_alignment.py b/tests/post_training/test_templates/test_channel_alignment.py index 950852ef3c2..27032965e05 100644 --- a/tests/post_training/test_templates/test_channel_alignment.py +++ b/tests/post_training/test_templates/test_channel_alignment.py @@ -492,5 +492,5 @@ def test_statistic_collectors(self, inplace_ref, q_ref): for aggr in statistic_collector.aggregators.values(): assert isinstance(aggr, MedianAggregator) assert aggr.num_samples == num_samples_ref - assert not aggr._keepdims + assert aggr._keepdims assert aggr._aggregation_axes == (0,)