Make NNCF common tensor statistics code pass mypy checks (#2865)
### Changes

General: 
- Added the target code path `nncf/common/tensor_statistics` to `.mypy.ini`.
- Fixed mypy errors caused by type inconsistencies (a representative sketch of these fixes follows this list).
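
A minimal sketch of the annotation patterns applied in the diff below; the class and attribute names here are illustrative, not copied from NNCF:

```python
from collections import deque
from typing import Any, Deque, Optional, cast


class WindowedCollector:
    """Illustrative collector showing the typical mypy fixes."""

    def __init__(self, window_size: Optional[int] = None) -> None:
        # Annotate container attributes so mypy knows the element type.
        self._samples: Deque[Any] = deque(maxlen=window_size)
        self._num_samples: Optional[int] = None

    @property
    def num_samples(self) -> Optional[int]:
        # The return type is widened to Optional[int] because the value may be unset.
        return self._num_samples

    def register(self, sample: Any) -> None:
        # cast() documents the expected runtime type where mypy cannot infer it.
        self._samples.append(cast(int, sample))
```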

Used `# type: ignore` in the following cases (see the sketch after this list):

- "has no attribute" errors;
- incompatible types that could not be resolved without substantial code changes.
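
A hedged illustration of where such a suppression is placed; the function and attribute here are hypothetical and only show the pattern, not code from this PR:

```python
def read_shape(tensor: object) -> object:
    # mypy reports '"object" has no attribute "shape"' on this access; the error
    # is suppressed locally rather than restructuring the surrounding classes.
    return tensor.shape  # type: ignore
```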

### Related tickets

Closes Issue #2493 

### Tests

To validate that the changes do not alter runtime behavior, a subset of the existing pytest tests was run.

---------

Co-authored-by: Daniil Lyakhov <[email protected]>
rk119 and daniil-lyakhov authored Aug 30, 2024
1 parent 74cc471 commit ff967e8
Showing 6 changed files with 98 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
@@ -1,5 +1,5 @@
[mypy]
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/, nncf/common/tensor_statistics
follow_imports = silent
strict = True

11 changes: 6 additions & 5 deletions nncf/common/tensor_statistics/aggregator.py
@@ -20,8 +20,9 @@
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.data.dataset import DataItem
from nncf.data.dataset import Dataset
from nncf.tensor import Tensor
from nncf.data.dataset import ModelInput

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
@@ -36,7 +37,7 @@ class StatisticsAggregator(ABC):
Base class for statistics collection.
"""

def __init__(self, dataset: Dataset):
def __init__(self, dataset: Dataset[DataItem, ModelInput]):
self.dataset = dataset
self.stat_subset_size = None
self.statistic_points = StatisticPointsContainer()
@@ -65,12 +66,12 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
model_transformer = factory.ModelTransformerFactory.create(model)
merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
model_with_outputs: TModel = model_transformer.transform(transformation_layout)
engine = factory.EngineFactory.create(model_with_outputs)

iterations_number = self._get_iterations_number()
empty_statistics = True
for input_data in track(
for input_data in track( # type: ignore
islice(self.dataset.get_inference_data(), iterations_number),
total=iterations_number,
description="Statistics collection",
@@ -141,7 +142,7 @@ def _get_merged_statistic_points(

@staticmethod
@abstractmethod
def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
"""
Post-process model outputs for the further statistics collection.
93 changes: 48 additions & 45 deletions nncf/common/tensor_statistics/collectors.py
@@ -12,9 +12,10 @@
from abc import ABC
from abc import abstractmethod
from collections import deque
from typing import List, Optional, Tuple
from typing import Any, Deque, Dict, List, Optional, Tuple, Union, cast

import numpy as np
from numpy.typing import NDArray

from nncf.common.tensor import NNCFTensor
from nncf.common.tensor import TensorType
@@ -39,7 +40,7 @@ def __init__(self, reduction_shape: Optional[ReductionAxes] = None, num_samples:
self._num_samples = num_samples

@property
def num_samples(self) -> int:
def num_samples(self) -> Union[int, None]:
return self._num_samples

def register_input(self, x: TensorType) -> TensorType:
@@ -49,38 +50,38 @@ def register_input(self, x: TensorType) -> TensorType:
if self._num_samples is not None and self._collected_samples >= self._num_samples:
return x
if self._reduction_shape is None:
self._reduction_shape = tuple(range(len(x.shape)))
self._reduction_shape = tuple(range(len(cast(NNCFTensor, x).shape)))
self._register_input(x)
self._collected_samples += 1
return x

@abstractmethod
def _register_input(self, x: TensorType):
def _register_input(self, x: TensorType) -> None:
pass

def get_statistics(self):
def get_statistics(self) -> None:
"""Returns collected statistics, if present."""
if self._collected_samples == 0:
raise StatisticsNotCollectedError()
return self._get_statistics()

@abstractmethod
def _get_statistics(self):
def _get_statistics(self) -> None:
pass

def enable(self):
def enable(self) -> None:
self._enabled = True

def disable(self):
def disable(self) -> None:
self._enabled = False

def reset(self):
def reset(self) -> None:
"""Resets all the statistics in the collector."""
self._collected_samples = 0
self._reset()

@abstractmethod
def _reset(self):
def _reset(self) -> None:
pass

def collected_samples(self) -> int:
@@ -102,9 +103,9 @@ def __init__(
self, reduction_shape: Optional[ReductionAxes] = None, num_samples: int = None, window_size: int = None
):
super().__init__(reduction_shape, num_samples)
self._samples = deque(maxlen=window_size)
self._samples: Deque[int] = deque(maxlen=window_size)

def _reset(self):
def _reset(self) -> None:
self._samples.clear()


@@ -121,10 +122,10 @@ def __init__(self, use_abs_max: bool, reduction_shape: ReductionAxes, num_sample

@staticmethod
@abstractmethod
def _get_processor():
def _get_processor() -> Any:
pass

def _register_input_common(self, x: NNCFTensor):
def _register_input_common(self, x: NNCFTensor) -> None:
min_reduced = self._tensor_processor.reduce_min(x, self._reduction_shape)
if self._use_abs_max:
x = self._tensor_processor.abs(x)
Expand All @@ -140,7 +141,7 @@ def _register_input_common(self, x: NNCFTensor):
else:
self._max_values = self._tensor_processor.max(max_reduced, self._max_values)

def _reset(self):
def _reset(self) -> None:
self._min_values = None
self._max_values = None

@@ -164,15 +165,15 @@ def __init__(
self._use_abs_max = use_abs_max
self._tensor_processor = self._get_processor()

self._all_min_values = deque(maxlen=window_size)
self._all_max_values = deque(maxlen=window_size)
self._all_min_values: Deque[int] = deque(maxlen=window_size)
self._all_max_values: Deque[int] = deque(maxlen=window_size)

@staticmethod
@abstractmethod
def _get_processor():
def _get_processor() -> Any:
pass

def _register_input_common(self, x: NNCFTensor):
def _register_input_common(self, x: NNCFTensor) -> None:
min_reduced = self._tensor_processor.reduce_min(x, self._reduction_shape)
if self._use_abs_max:
x = self._tensor_processor.abs(x)
@@ -186,14 +187,14 @@ def _register_input_common(self, x: NNCFTensor):
self._all_max_values.append(max_reduced)

@abstractmethod
def _min_aggregate(self):
def _min_aggregate(self) -> None:
pass

@abstractmethod
def _max_aggregate(self):
def _max_aggregate(self) -> None:
pass

def _reset(self):
def _reset(self) -> None:
self._all_min_values.clear()
self._all_max_values.clear()

@@ -217,13 +218,13 @@ def __init__(
self._use_means_of_mins = use_means_of_mins
self._use_means_of_maxs = use_means_of_maxs

def _min_aggregate(self):
def _min_aggregate(self) -> Any:
stacked_min = self._tensor_processor.stack(self._all_min_values)
if self._use_means_of_mins:
return self._tensor_processor.mean(stacked_min, axis=0)
return self._tensor_processor.reduce_min(stacked_min, axis=0)

def _max_aggregate(self):
def _max_aggregate(self) -> Any:
stacked_max = self._tensor_processor.stack(self._all_max_values)
if self._use_means_of_maxs:
return self._tensor_processor.mean(stacked_max, axis=0)
@@ -235,11 +236,11 @@ class MeanMinMaxStatisticCollector(MinMaxOfflineStatisticCollectorBase):
Collector aggregates mean of minimum values and mean of maximum values.
"""

def _min_aggregate(self):
def _min_aggregate(self) -> Any:
stacked_min = self._tensor_processor.stack(self._all_min_values)
return self._tensor_processor.mean(stacked_min, axis=0)

def _max_aggregate(self):
def _max_aggregate(self) -> Any:
stacked_max = self._tensor_processor.stack(self._all_max_values)
return self._tensor_processor.mean(stacked_max, axis=0)

@@ -259,30 +260,30 @@ def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_
super().__init__(num_samples=num_samples)
self._channel_axis = channel_axis
self._tensor_processor = self._get_processor()
self._all_values = deque(maxlen=window_size)
self._all_shapes = deque(maxlen=window_size)
self._all_values: Deque[int] = deque(maxlen=window_size)
self._all_shapes: Deque[int] = deque(maxlen=window_size)

@staticmethod
@abstractmethod
def _get_processor():
def _get_processor() -> Any:
pass

def _register_input_common(self, x: NNCFTensor):
def _register_input_common(self, x: NNCFTensor) -> None:
if self._channel_axis == 0:
self._all_values.append(self._tensor_processor.batch_mean(x))
else:
self._all_values.append(self._tensor_processor.mean_per_channel(x, self._channel_axis))
self._all_shapes.append(x.shape)
self._all_shapes.append(cast(int, x.shape))

def _reset(self):
def _reset(self) -> None:
self._all_values.clear()
self._all_shapes.clear()

def _mean_aggregate(self):
def _mean_aggregate(self) -> Any:
all_values_stack = self._tensor_processor.stack(self._all_values)
return self._tensor_processor.mean(all_values_stack, 0)

def _shape(self):
def _shape(self) -> Any:
return self._all_shapes[0]


@@ -298,17 +299,17 @@ def __init__(self, num_samples: Optional[int] = None) -> None:
the number of samples that will be processed.
"""
super().__init__(num_samples=num_samples)
self._all_values = []
self._all_values: List[int] = []

@staticmethod
@abstractmethod
def _get_processor():
def _get_processor() -> Any:
pass

def _register_input_common(self, x: NNCFTensor):
self._all_values.append(x.tensor)
def _register_input_common(self, x: NNCFTensor) -> None:
self._all_values.append(cast(int, x.tensor))

def _reset(self):
def _reset(self) -> None:
self._all_values.clear()


@@ -317,8 +318,10 @@ class MedianMADStatisticCollector(OfflineTensorStatisticCollector):
Collector estimates median and median absolute deviation (MAD).
"""

def _prepare_statistics(self):
per_channel_history = get_per_channel_history(self._samples, list(self._reduction_shape), discard_zeros=True)
def _prepare_statistics(self) -> Tuple[NDArray[Any], NDArray[Any]]:
per_channel_history = get_per_channel_history(
self._samples, cast(List[int], self._reduction_shape), discard_zeros=True
)
per_channel_median = [np.median(channel_hist) for channel_hist in per_channel_history]
per_channel_mad = []
for idx, median in enumerate(per_channel_median):
@@ -343,8 +346,8 @@ def __init__(
super().__init__(reduction_shape, num_samples, window_size)
self._percentiles_to_collect = percentiles_to_collect

def _prepare_statistics(self):
per_channel_history = get_per_channel_history(self._samples, list(self._reduction_shape))
def _prepare_statistics(self) -> Dict[float, Any]:
per_channel_history = get_per_channel_history(self._samples, cast(List[int], self._reduction_shape))
percentile_vs_values_dict = {}
for pc in self._percentiles_to_collect:
per_channel_percentiles = [np.percentile(channel_hist, pc) for channel_hist in per_channel_history]
@@ -366,10 +369,10 @@ def __init__(
window_size: int = None,
):
super().__init__(reduction_shape, num_samples, window_size)
self._all_pct_values = {}
self._all_pct_values: Dict[float, Any] = {}
for pc in percentiles_to_collect:
self._all_pct_values[pc] = deque(maxlen=window_size)

def _reset(self):
def _reset(self) -> None:
for _, val in self._all_pct_values.items():
val.clear()
14 changes: 8 additions & 6 deletions nncf/common/tensor_statistics/reduction.py
@@ -9,10 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import List, Tuple
from typing import Any, Deque, List, Tuple

import numpy as np
from numpy.typing import NDArray


def get_channel_count_and_dim_idx(scale_shape: List[int]) -> Tuple[int, int]:
@@ -25,7 +25,7 @@ def get_channel_count_and_dim_idx(scale_shape: List[int]) -> Tuple[int, int]:
return channel_count, channel_dim_idx


def split_into_channels(input_: np.ndarray, scale_shape: List[int]) -> List[np.ndarray]:
def split_into_channels(input_: NDArray[Any], scale_shape: List[int]) -> List[NDArray[Any]]:
channel_count, channel_dim_idx = get_channel_count_and_dim_idx(scale_shape)
channel_first_tensor = np.moveaxis(input_, channel_dim_idx, 0)
if channel_count == 1:
@@ -37,9 +37,11 @@ def split_into_channels(input_: np.ndarray, scale_shape: List[int]) -> List[np.n
return ret_list


def get_per_channel_history(raw_input_history: deque, scale_shape: List[int], discard_zeros=False) -> List:
def get_per_channel_history(
raw_input_history: Deque[Any], scale_shape: List[int], discard_zeros: bool = False
) -> List[Any]:
channel_count, _ = get_channel_count_and_dim_idx(scale_shape)
per_channel_history = [None for i in range(channel_count)]
per_channel_history: List[Any] = [None for i in range(channel_count)]
for _ in range(len(raw_input_history)):
entry = raw_input_history.popleft()
split = split_into_channels(entry, scale_shape)
@@ -59,7 +61,7 @@ def get_per_channel_history(raw_input_history: deque, scale_shape: List[int], di
return per_channel_history


def np_percentile_reduce_like(input_: np.array, ref_tensor_shape: Tuple[int], q: float) -> np.array:
def np_percentile_reduce_like(input_: NDArray[Any], ref_tensor_shape: Tuple[int], q: float) -> NDArray[Any]:
numel = np.prod(ref_tensor_shape)
if numel == 1:
return np.array([np.percentile(input_, q)])