Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTQ][Torch][Native] Experimental tensor statistics migration #2117

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 96 additions & 41 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import ABC
from abc import abstractmethod
from collections import deque
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np

Expand All @@ -21,14 +21,13 @@
from nncf.common.tensor import TensorType
from nncf.common.tensor_statistics.reduction import get_per_channel_history

ReductionShape = Tuple[int]
MaskedReduceFN = Callable[[NNCFTensor, Union[int, tuple, list], NNCFTensor, bool], NNCFTensor]
ReductionAxes = Tuple[int]


class TensorStatisticCollectorBase(ABC):
"""Collector estimate statistics at the quantization point based on the provided reduction shape."""

def __init__(self, reduction_shape: Optional[ReductionShape] = None, num_samples: Optional[int] = None):
def __init__(self, reduction_shape: Optional[ReductionAxes] = None, num_samples: Optional[int] = None):
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
"""
Initializes Tensor Statistic Collector

Expand Down Expand Up @@ -101,7 +100,7 @@ class OfflineTensorStatisticCollector(TensorStatisticCollectorBase):
"""Collects statistics in offline regime by storing the data and aggregating it afterwards."""

def __init__(
self, reduction_shape: Optional[ReductionShape] = None, num_samples: int = None, window_size: int = None
self, reduction_shape: Optional[ReductionAxes] = None, num_samples: int = None, window_size: int = None
):
super().__init__(reduction_shape, num_samples)
self._samples = deque(maxlen=window_size)
Expand All @@ -117,7 +116,7 @@ class NNCFCollectorTensorProcessor(ABC):

@staticmethod
@abstractmethod
def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor:
def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
"""
Computes minimum of elements across dimensions of NNCFTensor.

Expand All @@ -130,7 +129,7 @@ def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = Fa

@staticmethod
@abstractmethod
def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor:
def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
"""
Computes maximum of elements across dimensions of NNCFTensor.

Expand Down Expand Up @@ -175,7 +174,7 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:

@staticmethod
@abstractmethod
def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor:
def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
"""
Computes the mean of elements across given dimensions of NNCFTensor.

Expand All @@ -188,7 +187,7 @@ def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTe

@staticmethod
@abstractmethod
def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor:
def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
"""
Computes the median of elements across given dimensions of NNCFTensor.

Expand All @@ -199,9 +198,11 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_mean(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
"""
Computes the masked mean of elements across given dimensions of NNCFTensor.

Expand All @@ -214,9 +215,11 @@ def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor,
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_median(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_median(
cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
"""
Computes the masked median of elements across given dimensions of NNCFTensor.

Expand Down Expand Up @@ -251,6 +254,16 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
:return: List of NNCFTensor.
"""

@staticmethod
@abstractmethod
def squeeze(x: NNCFTensor, dim: Optional[Union[int, Tuple[int, ...]]] = None) -> NNCFTensor:
"""
Remove axes of length one from x.

:param x: NNCFTensor to squeeze.
:param axis: Selects a subset of the entries of length one in the shape.
"""

@staticmethod
@abstractmethod
def sum(tensor: NNCFTensor) -> TensorElementsType:
Expand All @@ -264,18 +277,42 @@ def sum(tensor: NNCFTensor) -> TensorElementsType:
@staticmethod
@abstractmethod
def quantile(
tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False
tensor: NNCFTensor,
quantile: Union[float, List[float]],
axis: Union[int, Tuple[int, ...], List[int]],
keepdims: bool = False,
) -> List[TensorElementsType]:
"""
Compute the quantile-th percentile(s) of the data along the specified axis.
Compute the quantile(s) of the data along the specified axis.

:param tensor: Given NNCFTensor.
:params quantile: Percentile or sequence of percentiles to compute, which must be between
:params quantile: Quantile or sequence of quantiles to compute, which must be between
0 and 1 inclusive.
:param axis: Axis or axes along which the quantiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile(s) of the tensor elements.
"""

@classmethod
@abstractmethod
def percentile(
cls,
tensor: NNCFTensor,
percentile: Union[float, List[float]],
axis: Union[int, Tuple[int, ...], List[int]],
keepdims: bool = False,
) -> List[TensorElementsType]:
"""
Compute the percentile(s) of the data along the specified axis.

:param tensor: Given NNCFTensor.
:params percentile: percentile or sequence of percentiles to compute, which must be between
0 and 100 inclusive.
:param axis: Axis or axes along which the percentiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile-th percentile(s) of the tensor elements.
:returns: List of the percentile(s) of the tensor elements.
"""

@staticmethod
Expand All @@ -289,27 +326,47 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
:return: Reduced NNCFTensor.
"""

@classmethod
@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Computes the element-wise logical OR of the given input tensors.
Zeros are treated as False and nonzeros are treated as True.

:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise or operation between input_ and other tensor.
"""

@staticmethod
def less(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Return the truth value of (x1 < x2) element-wise.

:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise less operation between input_ and other tensor.
"""

@staticmethod
@abstractmethod
def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor:
def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
"""
Returns result of a substract b operation.
"""
Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements that
are smaller that alpha and bigger than 1 - alpha quantile and applies
given masked reduction function fn.

:param tensor: Given NNCFTensor.
:param fn: Masked reduce operation from the same NNCFCollectorTensorProcessor class.
:param axis: Axis along which the reduction function is computed.
:params alpha: Minimal percentile to filter outliers outside the range
[quantile(alpha), quantile(1 - alpha)]. Must be between 0 and 1. inclusive.
:returns: Result of given masked reduction function on filtered from outliers NNCFTensor.
@classmethod
@abstractmethod
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
"""
Returns binary mask from the input x which equal true for all elemets that are smaller than
corresponding machine epsilon.
"""


class MinMaxStatisticCollector(OnlineTensorStatisticCollector):
"""Collector estimates min of minimum values and max of maximum values."""

def __init__(self, use_abs_max: bool, reduction_shape: ReductionShape, num_samples: int = None):
def __init__(self, use_abs_max: bool, reduction_shape: ReductionAxes, num_samples: int = None):
super().__init__(reduction_shape, num_samples)
self._use_abs_max = use_abs_max
self._tensor_processor = self._get_processor()
Expand Down Expand Up @@ -353,7 +410,7 @@ def __init__(
self,
use_per_sample_stats: bool,
use_abs_max: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -407,7 +464,7 @@ def __init__(
use_abs_max: bool,
use_means_of_mins: bool,
use_means_of_maxs: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -447,17 +504,15 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector):
Collector that aggregates statistics as mean along a pre-assigned axis.
"""

def __init__(
self, reduction_shape: ReductionShape, num_samples: Optional[int] = None, window_size: Optional[int] = None
) -> None:
def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_size: Optional[int] = None) -> None:
"""
:param reduction_shape: The shape for the reduction while statistics collection.
For the MeanStatisticCollector this parameter contains the main axis.
:param channel_axis: The main axis for the reduction while statistics collection.
:param num_samples: Optional parameter for statistic collection that regulates
the number of samples that will be processed.
:param window_size: Optional maximum length for the statistic collection
"""
super().__init__(reduction_shape, num_samples)
super().__init__(num_samples=num_samples)
self._channel_axis = channel_axis
self._tensor_processor = self._get_processor()
self._all_values = deque(maxlen=window_size)
self._all_shapes = deque(maxlen=window_size)
Expand All @@ -468,10 +523,10 @@ def _get_processor():
pass

def _register_input_common(self, x: NNCFTensor):
if self._reduction_shape == 0:
if self._channel_axis == 0:
self._all_values.append(self._tensor_processor.batch_mean(x))
else:
self._all_values.append(self._tensor_processor.mean_per_channel(x, self._reduction_shape))
self._all_values.append(self._tensor_processor.mean_per_channel(x, self._channel_axis))
self._all_shapes.append(x.shape)

def _reset(self):
Expand Down Expand Up @@ -536,7 +591,7 @@ class PercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand All @@ -561,7 +616,7 @@ class MeanPercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand Down
7 changes: 7 additions & 0 deletions nncf/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class TensorStatistic(ABC):
"""Base class that stores statistic data"""

TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output"

@staticmethod
@abstractmethod
def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool:
Expand Down Expand Up @@ -63,6 +65,9 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool:


class MedianMADTensorStatistic(TensorStatistic):
MEDIAN_VALUES_STAT = "median_values"
MAD_VALUES_STAT = "mad_values"

def __init__(self, median_values, mad_values):
self.median_values = median_values
self.mad_values = mad_values
Expand All @@ -74,6 +79,8 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool:


class PercentileTensorStatistic(TensorStatistic):
PERCENTILE_VS_VALUE_DICT = "percentile_vs_values_dict"

def __init__(self, percentile_vs_values_dict):
self.percentile_vs_values_dict = percentile_vs_values_dict

Expand Down
Loading