Skip to content

Commit

Permalink
Legacy MeanStatisticCollector typing fix
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 26, 2023
1 parent cca8660 commit 55118d5
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,17 +499,14 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector):
Collector that aggregates statistics as mean along a pre-assigned axis.
"""

def __init__(
self, reduction_shape: ReductionAxes, num_samples: Optional[int] = None, window_size: Optional[int] = None
) -> None:
def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_size: Optional[int] = None) -> None:
"""
:param reduction_shape: The shape for the reduction while statistics collection.
For the MeanStatisticCollector this parameter contains the main axis.
:param channel_axis: The main axis for the reduction while statistics collection.
:param num_samples: Optional parameter for statistic collection that regulates
the number of samples that will be processed.
:param window_size: Optional maximum length for the statistic collection
"""
super().__init__(reduction_shape, num_samples)
super().__init__(channel_axis, num_samples)
self._tensor_processor = self._get_processor()
self._all_values = deque(maxlen=window_size)
self._all_shapes = deque(maxlen=window_size)
Expand Down

0 comments on commit 55118d5

Please sign in to comment.