Skip to content

Commit

Permalink
reduction_axes -> channel_axis for bc/fbc
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 26, 2023
1 parent ef99ed4 commit cca8660
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 32 deletions.
4 changes: 2 additions & 2 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import abstractmethod
from collections import defaultdict
from collections import deque
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

from nncf.common.tensor import TensorType
from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor
Expand Down Expand Up @@ -398,7 +398,7 @@ def get_tensor_collector_inputs(
return target_inputs

@staticmethod
def _build_statistic_container(statistic_container_cls: TensorStatistic, kwargs: Dict[Any, Any]):
def _build_statistic_container(statistic_container_cls: Type[TensorStatistic], kwargs: Dict[Any, Any]):
if issubclass(statistic_container_cls, MinMaxTensorStatistic):
return statistic_container_cls(
min_values=kwargs[MinMaxTensorStatistic.MIN_STAT], max_values=kwargs[MinMaxTensorStatistic.MAX_STAT]
Expand Down
16 changes: 14 additions & 2 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,26 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)


def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace=True):
def get_mean_statistic_collector(
num_samples: int, channel_axis: int, window_size: Optional[int] = None, inplace: bool = True
) -> TensorCollector:
"""
Mean statistic collector builder.
:param num_samples: Maximum number of samples to collect.
:param channel_axis: Channel axis to use during reduction phase.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
Aggregates all available collected statistics in case parameter is None.
:param inplace: Whether the mean reducer should be calculated inplace or out of place.
:return: Mean statistic collector.
"""
# TODO(dlyakhov): use inplace OVBatchMeanReducer and OVMeanPerChanelReducer
# after migration on openvino-dev=2023.0
inplace = False
if channel_axis == 0:
reducer = OVBatchMeanReducer(inplace)
else:
reducer = OVMeanPerChanelReducer(channel_axis, inplace)
reducer = OVMeanPerChanelReducer(channel_dim=channel_axis, inplace=inplace)
noop_reducer = OVNoopReducer()

kwargs = {
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
TargetType.POST_LAYER_OPERATION, node_name, port_id=OUTPUT_PORT_OF_NODE
)
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
channel_axis=channel_axis, num_samples=self.subset_size, inplace=self.inplace_statistics
)
statistic_container.add_statistic_point(
StatisticPoint(
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/bias_correction/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.utils.registry import Registry

Expand Down Expand Up @@ -87,15 +86,15 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: TargetPoint) -
@staticmethod
@abstractmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
:param reduction_axes: Channel axis for the statistics aggregation.
:param channel_axis: Channel axis for the statistics aggregation.
:param inplace: Whether to calculate statistic inplace or not.
:param num_samples: Maximum number of samples to collect.
:param window_size: The maximum size of the samples queue.
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/bias_correction/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.model_utils import remove_fq_from_inputs
from nncf.onnx.graph.node_utils import get_bias_value
Expand Down Expand Up @@ -77,12 +76,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: ONNXTargetPoin

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)
return ONNXMeanStatisticCollector(channel_axis, num_samples, window_size)

@staticmethod
def raw_statistic_collector(inplace: bool, num_samples: int = None) -> ONNXMeanStatisticCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
Expand All @@ -31,7 +30,7 @@
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor
from nncf.openvino.statistics.collectors import get_mean_stat_collector
from nncf.openvino.statistics.collectors import get_mean_statistic_collector
from nncf.openvino.statistics.collectors import get_raw_stat_collector
from nncf.openvino.tensor import OVNNCFTensor
from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS
Expand Down Expand Up @@ -65,12 +64,12 @@ def output_insertion_command(nncf_graph: NNCFGraph, target_point: OVTargetPoint)

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)
return get_mean_statistic_collector(num_samples, channel_axis, window_size, inplace)

@staticmethod
def raw_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _add_statistic_point(self, container: StatisticPointsContainer, point: Targe
:param axis: Channel axis for the statistics calculation.
"""
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
channel_axis=axis, num_samples=self.subset_size, inplace=self.inplace_statistics
)
container.add_statistic_point(
StatisticPoint(target_point=point, tensor_collector=stat_collector, algorithm=self._algorithm_key)
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/fast_bias_correction/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.utils.registry import Registry

Expand Down Expand Up @@ -79,15 +78,15 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> Transform
@staticmethod
@abstractmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorStatisticCollectorBase:
"""
Returns backend-specific mean statistic collector.
:param reduction_axes: Channel axes for the statistics aggregation.
:param channel_axis: Channel axes for the statistics aggregation.
:param inplace: Whether to calculate statistic inplace or not.
:param num_samples: Maximum number of samples to collect.
:param window_size: The maximum size of the samples queue.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.onnx.graph.node_utils import get_bias_value
from nncf.onnx.graph.node_utils import is_any_weight_quantized
Expand Down Expand Up @@ -64,12 +63,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> ONNXModel

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> ONNXMeanStatisticCollector:
return ONNXMeanStatisticCollector(reduction_axes, num_samples, window_size)
return ONNXMeanStatisticCollector(channel_axis, num_samples, window_size)

@staticmethod
def get_sub_input_output_names(subgraph: onnx.ModelProto) -> Tuple[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
Expand All @@ -28,7 +27,7 @@
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor
from nncf.openvino.statistics.collectors import get_mean_stat_collector
from nncf.openvino.statistics.collectors import get_mean_statistic_collector
from nncf.openvino.tensor import OVNNCFTensor
from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
Expand Down Expand Up @@ -56,12 +55,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelEx

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_stat_collector(num_samples, reduction_axes, window_size, inplace)
return get_mean_statistic_collector(num_samples, channel_axis, window_size, inplace)

@staticmethod
def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from nncf.common.graph import NNCFNode
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS
Expand All @@ -34,7 +33,7 @@
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
from nncf.torch.tensor_statistics.collectors import get_mean_statisitic_collector
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector


@ALGO_BACKENDS.register(BackendType.TORCH)
Expand Down Expand Up @@ -68,12 +67,12 @@ def model_extraction_command(inputs: List[str], outputs: List[str]) -> PTModelEx

@staticmethod
def mean_statistic_collector(
reduction_axes: ReductionAxes,
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_statisitic_collector(num_samples, reduction_axes, window_size)
return get_mean_statistic_collector(num_samples, channel_axis, window_size)

@staticmethod
def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
Expand Down
4 changes: 2 additions & 2 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def get_mean_percentile_statistic_collector(
return tensor_collector


def get_mean_statisitic_collector(
def get_mean_statistic_collector(
num_samples: int, channel_axis: int, window_size: Optional[int] = None
) -> TensorCollector:
"""
Expand All @@ -501,7 +501,7 @@ def get_mean_statisitic_collector(
if channel_axis == 0:
reducer = PTBatchMeanReducer()
else:
reducer = PTMeanPerChanelReducer(channel_axis)
reducer = PTMeanPerChanelReducer(channel_dim=channel_axis)
noop_reducer = PTNoopReducer()

kwargs = {
Expand Down

0 comments on commit cca8660

Please sign in to comment.