Skip to content

Commit

Permalink
StatisticsAggregator returns original tensor shape (#2213)
Browse files Browse the repository at this point in the history
### Changes

Make StatisticsAggregator keep the original tensor shape after
aggregation.

### Reason for changes

To support correct handling of statistics when batch_size > 1.

### Related tickets

121650

### Tests

All tests are updated accordingly
  • Loading branch information
kshpv authored Nov 6, 2023
1 parent 2944eb4 commit 94d1f9c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 90 deletions.
62 changes: 28 additions & 34 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
tensor_processor: NNCFCollectorTensorProcessor,
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
):
"""
:param tensor_processor: Backend-specific tensor processor.
Expand All @@ -134,14 +135,17 @@ def __init__(
:param num_samples: Maximum number of samples to collect. Aggregator
skips tensor registration if tensor registration was called num_samples times before.
Aggregator never skips registration if num_samples is None.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
Aggregates all available collected statistics in case parameter is None.
"""

self._tensor_processor = tensor_processor
self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes
self._keepdims = False
self._keepdims = True
self._num_samples = num_samples
self._collected_samples = 0
self._container = []
self._window_size = window_size
self._container = deque(maxlen=window_size)

@property
def num_samples(self) -> int:
Expand Down Expand Up @@ -594,20 +598,7 @@ def _aggregate_impl(self):
return self._container.shape


class TensorAggregatorBase(AggregatorBase, ABC):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size=None,
):
super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples)
self._window_size = window_size
self._container = deque(maxlen=window_size)


class OnlineAggregatorBase(TensorAggregatorBase, ABC):
class OnlineAggregatorBase(AggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn with following property:
fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]) where x1, x2, x3 are samples to aggregate.
Expand All @@ -622,19 +613,17 @@ def _register_reduced_input_impl(self, x: NNCFTensor) -> None:
else:
reduced = x
if 0 in self._aggregation_axes:
if self._container:
reduced = self._aggregation_fn(
self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False
)
self._container = reduced
stacked_tensors = self._tensor_processor.stack([reduced, *self._container], axis=0)
aggregated = self._aggregation_fn(stacked_tensors, axis=0, keepdims=self._keepdims)
aggregated = self._tensor_processor.squeeze(aggregated, 0)
self._container = [aggregated]
else:
self._container.append(reduced)

def _aggregate_impl(self) -> NNCFTensor:
if 0 in self._aggregation_axes:
if self._keepdims:
return self._tensor_processor.stack([self._container]).tensor
return self._container.tensor
return self._container[0].tensor
return self._tensor_processor.stack(self._container).tensor

@abstractmethod
Expand All @@ -652,7 +641,7 @@ def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keep
return self._tensor_processor.reduce_max(stacked_value, axis=axis, keepdims=keepdims)


class OfflineAggregatorBase(TensorAggregatorBase, ABC):
class OfflineAggregatorBase(AggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn which
does not fulfill property fn([x1, x2, x3]) == fn([fn([x1, x2]), x3])
Expand All @@ -665,7 +654,8 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:

def _aggregate_impl(self) -> NNCFTensor:
stacked_val = self._tensor_processor.stack(self._container)
return self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims).tensor
aggregated = self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims)
return self._tensor_processor.squeeze(aggregated, 0).tensor

@abstractmethod
def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keepdims: bool) -> NNCFTensor:
Expand Down Expand Up @@ -699,13 +689,19 @@ def __init__(
def _aggregate_impl(self) -> NNCFTensor:
stacked_samples = self._tensor_processor.stack(self._container)
low_values, high_values = self._tensor_processor.quantile(
stacked_samples, quantile=(self._quantile, 1 - self._quantile), axis=self._aggregation_axes
stacked_samples,
quantile=(self._quantile, 1 - self._quantile),
axis=self._aggregation_axes,
)
tp = self._tensor_processor
outliers_mask = tp.logical_or(tp.less(stacked_samples, low_values), tp.less(high_values, stacked_samples))
return self._aggregation_fn(
stacked_samples=stacked_samples, mask=outliers_mask, axis=self._aggregation_axes, keepdims=self._keepdims
).tensor
aggregated = self._aggregation_fn(
stacked_samples=stacked_samples,
mask=outliers_mask,
axis=self._aggregation_axes,
keepdims=self._keepdims,
)
return self._tensor_processor.squeeze(aggregated, 0).tensor

@abstractmethod
def _aggregation_fn(
Expand Down Expand Up @@ -734,7 +730,7 @@ def _aggregation_fn(
return self._tensor_processor.masked_median(stacked_samples, axis=axis, mask=mask, keepdims=keepdims)


class MedianAbsoluteDeviationAggregator(TensorAggregatorBase):
class MedianAbsoluteDeviationAggregator(AggregatorBase):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
Expand Down Expand Up @@ -780,7 +776,7 @@ def _aggregate_impl(self) -> Dict[str, NNCFTensor]:
}


class PercentileAggregator(TensorAggregatorBase):
class PercentileAggregator(AggregatorBase):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
Expand Down Expand Up @@ -850,9 +846,7 @@ def _moveaxes_flatten_cat(
transposed_t = tensor_processor.transpose(tensor, transpose_dims)
reshaped_tensors.append(tensor_processor.reshape(transposed_t, reshape_shape))

shape_after_aggregation = (1,) + tuple(
1 if idx in aggregation_axes else dim for idx, dim in enumerate(tensor_shape)
)
shape_after_aggregation = tuple(1 if idx in aggregation_axes else dim for idx, dim in enumerate(tensor_shape))
return tensor_processor.cat(reshaped_tensors, axis=0), shape_after_aggregation


Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.tensor_statistics.collectors import NNCFTensor
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase
from nncf.experimental.common.tensor_statistics.collectors import BatchMeanReducer
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import MaxReducer
Expand All @@ -33,7 +34,6 @@
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorAggregatorBase
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.quantization.advanced_parameters import StatisticsType
from nncf.torch.tensor import PTNNCFTensor
Expand Down Expand Up @@ -443,8 +443,8 @@ def get_percentile_tensor_collector(


def _get_collection_without_reduction(
aggregator_cls: TensorAggregatorBase,
statistic_cls: TensorAggregatorBase,
aggregator_cls: AggregatorBase,
statistic_cls: AggregatorBase,
reduction_axes: Tuple[int, ...],
aggregation_axes: Tuple[int, ...],
num_samples: int,
Expand Down
102 changes: 50 additions & 52 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,49 +59,47 @@ class OfflineAggregatorTestCase:


OFFLINE_AGGREGATORS_TEST_CASES = [
OfflineAggregatorTestCase(
aggregation_axes=None,
min_ref=np.array([[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]),
max_ref=np.array([[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]),
),
OfflineAggregatorTestCase(
aggregation_axes=(0,),
min_ref=np.array([[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]),
max_ref=np.array([[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]),
),
OfflineAggregatorTestCase(
aggregation_axes=(0, 2),
min_ref=np.array([[-50000, -28, -32]]),
max_ref=np.array([[50000, 28, 32]]),
aggregation_axes=(
0,
2,
),
min_ref=np.array([[[-50000, -28, -32]]]),
max_ref=np.array([[[50000, 28, 32]]]),
),
OfflineAggregatorTestCase(
aggregation_axes=(2,),
min_ref=np.array(
[
[[-50000, 5, 10]],
[[-40000, 4, 8]],
[[-30000, 3, 6]],
[[-20000, 2, 4]],
[[-10000, 1, 2]],
[[0, 0, 0]],
[[-6, -7, -8]],
[[-12, -14, -16]],
[[-18, -21, -24]],
[[-24, -28, -32]],
[[[-50000, 5, 10]]],
[[[-40000, 4, 8]]],
[[[-30000, 3, 6]]],
[[[-20000, 2, 4]]],
[[[-10000, 1, 2]]],
[[[0, 0, 0]]],
[[[-6, -7, -8]]],
[[[-12, -14, -16]]],
[[[-18, -21, -24]]],
[[[-24, -28, -32]]],
]
),
max_ref=np.array(
[
[[50000, -5, -10]],
[[40000, -4, -8]],
[[30000, -3, -6]],
[[20000, -2, -4]],
[[10000, -1, -2]],
[[0, 0, 0]],
[[6, 7, 8]],
[[12, 14, 16]],
[[18, 21, 24]],
[[24, 28, 32]],
[[[50000, -5, -10]]],
[[[40000, -4, -8]]],
[[[30000, -3, -6]]],
[[[20000, -2, -4]]],
[[[10000, -1, -2]]],
[[[0, 0, 0]]],
[[[6, 7, 8]]],
[[[12, 14, 16]]],
[[[18, 21, 24]]],
[[[24, 28, 32]]],
]
),
),
Expand Down Expand Up @@ -254,38 +252,38 @@ def test_min_max_aggregators(
assert self.all_close(max_aggregator.aggregate(), max_ref)

NO_OUTLIERS_TEST_PARAMS = [
(MeanAggregator, True, 1, 1404.5138888888905),
(MedianAggregator, True, 1, 24.0),
(MeanAggregator, True, 1, [1404.5138888888905]),
(MedianAggregator, True, 1, [24.0]),
(
MeanAggregator,
False,
1,
[2503.125, -2493.75, 5009.375, -4987.5, 7515.625, -7481.25, 10021.875, -9975.0, 12528.125],
),
(MedianAggregator, False, 1, [4.5, 5.0, 13.5, 10.0, 22.5, 15.0, 31.5, 20.0, 40.5]),
(MeanAggregator, True, 2, [2512.5, -1651.04166667, 3352.08333333]),
(MedianAggregator, True, 2, [13.0, 12.5, 21.0]),
(MeanAggregator, True, 2, [[2512.5, -1651.04166667, 3352.08333333]]),
(MedianAggregator, True, 2, [[13.0, 12.5, 21.0]]),
(MeanAggregator, False, 2, DEFALUT_3D_MEAN_VALUE),
(MedianAggregator, False, 2, DEFALUT_3D_MEDIAN_VALUE),
(MeanAggregator, True, 3, DEFALUT_3D_MEAN_VALUE),
(MedianAggregator, True, 3, DEFALUT_3D_MEDIAN_VALUE),
(MeanAggregator, True, 3, [DEFALUT_3D_MEAN_VALUE]),
(MedianAggregator, True, 3, [DEFALUT_3D_MEDIAN_VALUE]),
(MeanAggregator, False, 3, [DEFALUT_3D_MEAN_VALUE]),
(MedianAggregator, False, 3, [DEFALUT_3D_MEDIAN_VALUE]),
(default_test_mean_no_outlier, True, 1, 20.0893),
(default_test_median_no_outlier, True, 1, 30.0),
(default_test_mean_no_outlier, True, 1, [20.0893]),
(default_test_median_no_outlier, True, 1, [30.0]),
(
default_test_mean_no_outlier,
False,
1,
[4.16666667, 8.33333333, 12.5, 16.66666667, 20.83333333, 25.0, 29.16666667, 33.33333333, 37.5],
),
(default_test_median_no_outlier, False, 1, [5.0, 4.0, 15.0, 8.0, 25.0, 12.0, 35.0, 16.0, 45.0]),
(default_test_mean_no_outlier, True, 2, [16.66666667, 20.83333333, 25.0]),
(default_test_median_no_outlier, True, 2, [14.0, 10.0, 24.0]),
(default_test_mean_no_outlier, True, 2, [[16.66666667, 20.83333333, 25.0]]),
(default_test_median_no_outlier, True, 2, [[14.0, 10.0, 24.0]]),
(default_test_mean_no_outlier, False, 2, NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE),
(default_test_median_no_outlier, False, 2, NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE),
(default_test_mean_no_outlier, True, 3, NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE),
(default_test_median_no_outlier, True, 3, NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE),
(default_test_mean_no_outlier, True, 3, [NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE]),
(default_test_median_no_outlier, True, 3, [NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE]),
(default_test_mean_no_outlier, False, 3, [NO_OUTLIERS_DEFAULT_3D_MEAN_VALUE]),
(default_test_median_no_outlier, False, 3, [NO_OUTLIERS_DEFAULT_3D_MEDIAN_VALUE]),
]
Expand Down Expand Up @@ -348,8 +346,8 @@ def aggregator_cls_fixture(self, request):
"mad_values": np.array([2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5]),
},
(0, 1): {
"median_values": np.array(18.0),
"mad_values": np.array(12.0),
"median_values": np.array([18.0]),
"mad_values": np.array([12.0]),
},
},
PercentileAggregator: {
Expand All @@ -366,10 +364,10 @@ def aggregator_cls_fixture(self, request):
95: np.array([7.6, 15.2, 22.8, 30.4, 38.0, 45.6, 53.2, 60.8, 68.4]),
},
(0, 1): {
5: np.array(0.0),
10: np.array(0.0),
90: np.array(48.0),
95: np.array(56.0),
5: np.array([0.0]),
10: np.array([0.0]),
90: np.array([48.0]),
95: np.array([56.0]),
},
},
}
Expand All @@ -389,14 +387,14 @@ def test_mad_percentile_aggregators(self, MAD_precentile_aggregator_cls, tensor_

REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS = {
MedianAbsoluteDeviationAggregator: {
"median_values": np.array([28.5, 35.5, 43.5]),
"mad_values": np.array([24.0, 24.0, 24.0]),
"median_values": np.array([[28.5, 35.5, 43.5]]).reshape(1, 3, 1),
"mad_values": np.array([[[24.0, 24.0, 24.0]]]).reshape(1, 3, 1),
},
PercentileAggregator: {
5: np.array([0.95, 5.95, 9.95]),
10: np.array([1.9, 7.9, 15.5]),
90: np.array([75.1, 83.1, 91.1]),
95: np.array([77.05, 85.05, 93.05]),
5: np.array([[[0.95, 5.95, 9.95]]]).reshape(1, 3, 1),
10: np.array([[[1.9, 7.9, 15.5]]]).reshape(1, 3, 1),
90: np.array([[[75.1, 83.1, 91.1]]]).reshape(1, 3, 1),
95: np.array([[[77.05, 85.05, 93.05]]]).reshape(1, 3, 1),
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,5 +492,5 @@ def test_statistic_collectors(self, inplace_ref, q_ref):
for aggr in statistic_collector.aggregators.values():
assert isinstance(aggr, MedianAggregator)
assert aggr.num_samples == num_samples_ref
assert not aggr._keepdims
assert aggr._keepdims
assert aggr._aggregation_axes == (0,)

0 comments on commit 94d1f9c

Please sign in to comment.