From 55118d5121f0188100c7539c25a50768a47686f3 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 26 Sep 2023 18:02:39 +0200 Subject: [PATCH] Legacy MeanStatisticCollector typing fix --- nncf/common/tensor_statistics/collectors.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py index 005e4ceebd8..97cf1b8de99 100644 --- a/nncf/common/tensor_statistics/collectors.py +++ b/nncf/common/tensor_statistics/collectors.py @@ -499,17 +499,14 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector): Collector that aggregates statistics as mean along a pre-assigned axis. """ - def __init__( - self, reduction_shape: ReductionAxes, num_samples: Optional[int] = None, window_size: Optional[int] = None - ) -> None: + def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_size: Optional[int] = None) -> None: """ - :param reduction_shape: The shape for the reduction while statistics collection. - For the MeanStatisticCollector this parameter contains the main axis. + :param channel_axis: The main axis for the reduction while statistics collection. :param num_samples: Optional parameter for statistic collection that regulates the number of samples that will be processed. :param window_size: Optional maximum length for the statistic collection """ - super().__init__(reduction_shape, num_samples) + super().__init__(channel_axis, num_samples) self._tensor_processor = self._get_processor() self._all_values = deque(maxlen=window_size) self._all_shapes = deque(maxlen=window_size)