From 5150867ac43d8691599dbb37c114c38d40bf64d5 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Tue, 17 Oct 2023 15:13:44 +0100 Subject: [PATCH] Add SmoothQuant and ChannelAlignment to HyperparameterTuner (#2154) ### Changes - Add the `Pipeline` class. To use the algorithm with the `HyperparameterTuner`, it should be wrapped within a `Pipeline` object. ### Reason for changes - Add the SmoothQuant and ChannelAlignment algorithms to the parameter optimization process. ### Related tickets Ref: 117471 ### Tests --- .../hyperparameter_tuner/algorithm.py | 161 ++++++++++------ .../hyperparameter_tuner/param_grid.py | 51 ++++- nncf/quantization/algorithms/pipeline.py | 182 ++++++++++++++++++ .../algorithms/post_training/algorithm.py | 144 ++------------ .../algorithms/post_training/pipeline.py | 139 +++++++++++++ nncf/quantization/quantize_model.py | 10 +- tests/onnx/quantization/common.py | 4 + tests/onnx/quantization/test_ptq_params.py | 5 +- .../native/quantization/test_ptq_params.py | 5 +- .../test_templates/test_ptq_params.py | 20 +- .../test_templates/test_quantizer_config.py | 31 +-- tests/torch/ptq/test_fq_params_calculation.py | 14 +- tests/torch/ptq/test_ptq_params.py | 5 +- 13 files changed, 520 insertions(+), 251 deletions(-) create mode 100644 nncf/quantization/algorithms/pipeline.py create mode 100644 nncf/quantization/algorithms/post_training/pipeline.py diff --git a/nncf/quantization/algorithms/hyperparameter_tuner/algorithm.py b/nncf/quantization/algorithms/hyperparameter_tuner/algorithm.py index fba7b984278..2d0e123b1df 100644 --- a/nncf/quantization/algorithms/hyperparameter_tuner/algorithm.py +++ b/nncf/quantization/algorithms/hyperparameter_tuner/algorithm.py @@ -14,12 +14,12 @@ import functools import itertools import operator -from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union from nncf.common.factory import NNCFGraphFactory -from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.logging import nncf_logger +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import get_backend from nncf.common.utils.timer import timer from nncf.data.dataset import Dataset @@ -27,7 +27,8 @@ from nncf.quantization.algorithms.accuracy_control.evaluator import MetricResults from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset -from nncf.quantization.algorithms.algorithm import Algorithm +from nncf.quantization.algorithms.pipeline import Pipeline +from nncf.quantization.algorithms.pipeline import collect_statistics TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -111,7 +112,9 @@ def apply_combination(init_params: Dict[str, Any], combination: Combination) -> return params -def print_combination_and_score(title: str, combination: Combination, combination_score: float) -> None: +def print_combination_and_score( + title: str, combination: Combination, combination_score: Optional[float] = None +) -> None: """ Prints combination and score. 
@@ -126,7 +129,9 @@ def print_combination_and_score(title: str, combination: Combination, combinatio
     message = f"{title} {message}"
     nncf_logger.info(message)
 
-    nncf_logger.info(f"Score: {combination_score}")
+
+    if combination_score is not None:
+        nncf_logger.info(f"Score: {combination_score}")
 
 
 def find_best_combination(
@@ -186,7 +191,7 @@ class HyperparameterTuner:
             "param_name": [0.1, 0.2],
         }
 
-    The parameters names should be same as in `algorithm_cls.__init__()` method.
+    The parameter names should be the same as in the `pipeline_fn()` function.
 
     In case the "param_name" parameter is a dataclass object, there is a way
     to specify settings to try for its fields using the marker ":"
@@ -214,9 +219,9 @@ class HyperparameterTuner:
 
     def __init__(
        self,
-        algorithm_cls: Type[Algorithm],
+        pipeline_fn: Callable[..., Pipeline],
         init_params: Dict[str, Any],
-        param_grid: Dict[str, List[Any]],
+        param_grids: List[Dict[str, List[Any]]],
         calibration_dataset: Dataset,
         validation_fn: Callable[[Any, Iterable[Any]], Tuple[float, Union[None, List[float], List[List[TTensor]]]]],
         subset_size: int,
@@ -224,7 +229,7 @@ def __init__(
         quantized_metric_results: MetricResults,
     ):
         """
-        :param algorithm_cls: Class of algorithm.
+        :param pipeline_fn: Function that creates a pipeline.
         :param init_params: Initial set of parameters used to create the pipeline.
         :param param_grids: List of dictionaries with parameter names as keys and
             lists of parameter settings to try as values, one grid per pipeline step.
@@ -235,9 +240,9 @@ def __init__(
         :param initial_metric_results: Metric results for the initial model.
         :param quantized_metric_results: Metric results for the model quantized with `init_params`.
         """
-        self._algorithm_cls = algorithm_cls
+        self._pipeline_fn = pipeline_fn
         self._init_params = init_params
-        self._param_grid = param_grid
+        self._param_grids = param_grids
         self._calibration_dataset = calibration_dataset
         self._evaluator = Evaluator(validation_fn)
         self._subset_size = subset_size
@@ -246,12 +251,12 @@ def __init__(
 
         self._is_metric_mode = isinstance(self._initial_metric_results.values_for_each_item[0], float)
 
-        # # Will be initialized inside `apply()` method
+        # Will be initialized inside `apply()` method
         self._error_fn = None
 
-        # Will be initialized inside `_initialize_algorithms()` method
-        self._algorithms: Dict[CombinationKey, Algorithm] = {}
-        self._statistic_points = None
+        # Will be initialized inside `_prepare_pipeline_step()` method
+        self._pipelines: Dict[CombinationKey, Pipeline] = {}
+        self._step_index_to_statistics: Dict[int, StatisticPointsContainer] = {}
 
         self._calculated_scores: Dict[CombinationKey, float] = {}
 
@@ -275,58 +280,101 @@ def apply(self, model: TModel, validation_dataset: Dataset) -> TModel:
             self._error_fn,
         )
 
-        combinations = create_combinations(self._param_grid)
+        step_model = model  # The model on which the `step_index`-th pipeline step will be executed
+        best_settings = {}
+
+        for step_index, step_param_grid in enumerate(self._param_grids):
+            step_graph = NNCFGraphFactory.create(step_model)
+
+            # If there are no parameters to optimize for the current step, simply execute
+            # this pipeline step on the model.
+            if not step_param_grid:
+                # TODO(andrey-churkin): Think about how it can be avoided.
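+                # Since the step has no tunable parameters, its output does not
+                # depend on the search, so the step is executed once with the best
+                # settings found for the previous steps.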
+                params = apply_combination(self._init_params, best_settings)
+                pipeline = self._pipeline_fn(**params)
+                container = pipeline.get_statistic_points_for_step(step_index, step_model, step_graph)
+                step_statistics = collect_statistics(container, step_model, step_graph, self._calibration_dataset)
+                step_model = pipeline.run_step(step_index, step_statistics, step_model, step_graph)
+                continue
 
-        initial_graph = NNCFGraphFactory.create(model)
+            step_combinations = create_combinations(step_param_grid)
 
-        nncf_logger.info("Start initialization of algorithms")
-        with timer():
-            self._prepare_algorithms(model, initial_graph, combinations)
+            nncf_logger.info(f"Start preparation for the {step_index}-th pipeline step")
+            with timer():
+                self._prepare_pipeline_step(step_index, step_model, step_graph, step_combinations, best_settings)
 
-        combination_score_fn = functools.partial(
-            self._calculate_combination_score,
-            initial_model=model,
-            initial_graph=initial_graph,
-            dataset=validation_dataset,
-            subset_indices=subset_indices,
-        )
+            combination_score_fn = functools.partial(
+                self._calculate_combination_score,
+                step_index=step_index,
+                step_model=step_model,
+                step_graph=step_graph,
+                dataset=validation_dataset,
+                subset_indices=subset_indices,
+            )
+
+            nncf_logger.info("Start searching for the best combination of parameters")
+            with timer():
+                step_best_combination_key = find_best_combination(
+                    step_combinations, combination_score_fn, step_param_grid
+                )
 
-        nncf_logger.info("Start search best combination of parameters")
-        with timer():
-            best_combination_key = find_best_combination(combinations, combination_score_fn, self._param_grid)
+            best_settings.update(step_combinations[step_best_combination_key])
+            pipeline = self._pipelines[step_best_combination_key]
+            step_model = pipeline.run_step(
+                step_index, self._step_index_to_statistics[step_index], step_model, step_graph
+            )
 
-        algorithm = self._algorithms[best_combination_key]
-        result_model = algorithm.apply(model, initial_graph, self._statistic_points)
+        print_combination_and_score("Final best combination of parameters:", best_settings)
 
-        return result_model
+        return step_model
 
-    def _prepare_algorithms(
-        self, initial_model: TModel, initial_graph: NNCFGraph, combinations: Dict[CombinationKey, Combination]
+    def _prepare_pipeline_step(
+        self,
+        step_index: int,
+        step_model: TModel,
+        step_graph: NNCFGraph,
+        step_combinations: Dict[CombinationKey, Combination],
+        best_settings: Combination,
     ) -> None:
         """
-        Creates algorithm for each combination of parameters. Collects statistics for
-        created algorithms.
-
-        :param initial_model: Input model used to collect statistics for algorithms.
-        :param combinations: Combinations of parameters.
+        Creates a separate pipeline for each combination from `step_combinations`.
+        Each combination only changes the parameters of the `step_index`-th pipeline
+        step. After that, this method combines the statistics required to execute the
+        `step_index`-th pipeline step and collects them using `step_model`, `step_graph`,
+        and the calibration dataset.
+
+        :param step_index: Zero-based index of the pipeline step that should be prepared.
+        :param step_model: A model.
+        :param step_graph: A graph associated with the model.
+        :param step_combinations: Combinations that change parameters only for the `step_index`-th pipeline step.
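+        :param best_settings: Best combination of parameters found for the previous pipeline steps.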
""" - for combination_key, combination in combinations.items(): - kwargs = apply_combination(self._init_params, combination) - self._algorithms[combination_key] = self._algorithm_cls(**kwargs) - - # Collect required statistics for created algorithms - stats_aggregator = StatisticsAggregatorFactory.create(initial_model, self._calibration_dataset) - for algorithm in self._algorithms.values(): - statistic_points = algorithm.get_statistic_points(initial_model, initial_graph) - stats_aggregator.register_statistic_points(statistic_points) - stats_aggregator.collect_statistics(initial_model, initial_graph) - self._statistic_points = stats_aggregator.statistic_points + # Create a separate pipeline for each combination + + # TODO(andrey-churkin): Think about how it can be avoided. In an ideal scenario, + # we would have only one pipeline and set parameters directly within it. + self._pipelines = {} + for combination_key, combination in step_combinations.items(): + settings = {} + settings.update(combination) + settings.update(best_settings) + kwargs = apply_combination(self._init_params, settings) + self._pipelines[combination_key] = self._pipeline_fn(**kwargs) + + # Collect statistics required to execute `step_index`-th pipeline step + containers = [ + pipeline.get_statistic_points_for_step(step_index, step_model, step_graph) + for pipeline in self._pipelines.values() + ] + self._step_index_to_statistics[step_index] = collect_statistics( + containers, step_model, step_graph, self._calibration_dataset + ) def _calculate_combination_score( self, combination_key: CombinationKey, - initial_model: TModel, - initial_graph: NNCFGraph, + step_index: int, + step_model: TModel, + step_graph: NNCFGraph, dataset: Dataset, subset_indices: List[int], ) -> float: @@ -343,8 +391,11 @@ def _calculate_combination_score( if combination_key in self._calculated_scores: return self._calculated_scores[combination_key] - algorithm = self._algorithms[combination_key] - model = algorithm.apply(initial_model, initial_graph, self._statistic_points) + pipeline = self._pipelines[combination_key] + model = pipeline.run_from_step( + step_model, self._calibration_dataset, step_graph, step_index, self._step_index_to_statistics + ) + score = self._validate_model(model, dataset, subset_indices) self._calculated_scores[combination_key] = score diff --git a/nncf/quantization/algorithms/hyperparameter_tuner/param_grid.py b/nncf/quantization/algorithms/hyperparameter_tuner/param_grid.py index fe0baa1b833..4874fc80b42 100644 --- a/nncf/quantization/algorithms/hyperparameter_tuner/param_grid.py +++ b/nncf/quantization/algorithms/hyperparameter_tuner/param_grid.py @@ -10,19 +10,24 @@ # limitations under the License. 
 import itertools
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 from nncf.common.quantization.structs import QuantizationPreset
+from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
+from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment
+from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
+from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
+from nncf.quantization.algorithms.pipeline import Pipeline
+from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
 from nncf.quantization.range_estimator import AggregatorType
 from nncf.quantization.range_estimator import RangeEstimatorParameters
 from nncf.quantization.range_estimator import StatisticsCollectorParameters
 from nncf.quantization.range_estimator import StatisticsType
 
+ParamGrid = Dict[str, List[Any]]
 
-def get_quantization_param_grid() -> Dict[str, Any]:
-    """
-    Returns params grid for post-training quantization algorithm.
-    """
+
+def _get_minmax_quantization_param_grid() -> ParamGrid:
     min_param_values = [
         StatisticsCollectorParameters(
             statistics_type=StatisticsType.MIN,
@@ -58,7 +63,6 @@ def get_quantization_param_grid() -> Dict[str, Any]:
 
     param_grid = {
         "preset": [QuantizationPreset.PERFORMANCE, QuantizationPreset.MIXED],
-        "fast_bias_correction": [True, False],
         "advanced_parameters:weights_range_estimator_params": [
             RangeEstimatorParameters(
                 min=StatisticsCollectorParameters(statistics_type=StatisticsType.MIN),
@@ -70,5 +74,38 @@ def get_quantization_param_grid() -> Dict[str, Any]:
             for min_v, max_v in itertools.product(min_param_values, max_param_values)
         ],
     }
-
     return param_grid
+
+
+def _get_smooth_quant_param_grid() -> ParamGrid:
+    return {"advanced_parameters:smooth_quant_alpha": [0.15, 0.25, 0.5, 0.75, 0.95]}
+
+
+def _get_channel_alignment_param_grid() -> ParamGrid:
+    return {}
+
+
+def _get_bias_correction_param_grid() -> ParamGrid:
+    return {"fast_bias_correction": [True, False]}
+
+
+def get_quantization_param_grids(pipeline: Pipeline) -> List[ParamGrid]:
+    """
+    Returns a parameter grid for each step of the given post-training quantization pipeline.
+    """
+    algorithm_cls_to_param_grid = {
+        SmoothQuant: _get_smooth_quant_param_grid(),
+        ChannelAlignment: _get_channel_alignment_param_grid(),
+        MinMaxQuantization: _get_minmax_quantization_param_grid(),
+        FastBiasCorrection: _get_bias_correction_param_grid(),
+        BiasCorrection: _get_bias_correction_param_grid(),
+    }
+
+    param_grids = []
+    for step in pipeline.pipeline_steps:
+        param_grid = {}
+        for algorithm in step:
+            param_grid.update(algorithm_cls_to_param_grid[algorithm.__class__])
+        param_grids.append(param_grid)
+
+    return param_grids
diff --git a/nncf/quantization/algorithms/pipeline.py b/nncf/quantization/algorithms/pipeline.py
new file mode 100644
index 00000000000..951bd436e2b
--- /dev/null
+++ b/nncf/quantization/algorithms/pipeline.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, TypeVar, Union
+
+from nncf.common.factory import NNCFGraphFactory
+from nncf.common.factory import StatisticsAggregatorFactory
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
+from nncf.data.dataset import Dataset
+from nncf.quantization.algorithms.algorithm import Algorithm
+
+TModel = TypeVar("TModel")
+PipelineStep = List[Algorithm]
+
+
+def collect_statistics(
+    containers: Union[StatisticPointsContainer, List[StatisticPointsContainer]],
+    model: TModel,
+    graph: NNCFGraph,
+    dataset: Dataset,
+) -> StatisticPointsContainer:
+    """
+    Utility function for collecting statistics for a model.
+
+    :param containers: A container (or list of containers) with statistic points
+        that need to be collected.
+    :param model: A model.
+    :param graph: A graph associated with the model.
+    :param dataset: A dataset.
+    :return: Collected statistics.
+    """
+    if not isinstance(containers, list):
+        containers = [containers]
+
+    statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
+    for container in containers:
+        statistics_aggregator.register_statistic_points(container)
+    statistics_aggregator.collect_statistics(model, graph)
+
+    return statistics_aggregator.statistic_points
+
+
+class Pipeline:
+    """
+    A class for creating pipelines that apply algorithms to a model.
+
+    This class is used for creating custom model processing pipelines
+    that encapsulate a series of algorithms to be applied to a model
+    using a provided dataset.
+
+    A pipeline consists of pipeline steps. Each pipeline step is a
+    sequence of Algorithm class instances whose statistic points are
+    combined and collected using the model obtained after the previous
+    pipeline step. The collected statistic points are used for all
+    algorithms in this step.
+    """
+
+    def __init__(self, pipeline_steps: List[PipelineStep]):
+        """
+        :param pipeline_steps: A sequence of pipeline steps to be executed in order.
+        """
+        self._pipeline_steps = pipeline_steps
+
+    @property
+    def pipeline_steps(self) -> List[PipelineStep]:
+        """
+        Property that defines the sequence of distinct pipeline steps to
+        be executed in order.
+
+        :return: A sequence of pipeline steps to be executed in order.
+        """
+        return self._pipeline_steps
+
+    def run(self, model: TModel, dataset: Dataset) -> TModel:
+        """
+        Executes the pipeline on the provided model.
+
+        :param model: A model to which the pipeline will be applied.
+        :param dataset: A dataset that holds the data items for algorithms.
+        :return: The updated model after executing the entire pipeline.
+        """
+        return self.run_from_step(model, dataset)
+
+    def run_step(
+        self,
+        step_index: int,
+        step_statistics: StatisticPointsContainer,
+        model: TModel,
+        graph: NNCFGraph,
+    ) -> TModel:
+        """
+        Executes a single pipeline step on the provided model.
+
+        :param step_index: Zero-based index of the pipeline step that should be executed.
+        :param step_statistics: Statistics required to execute the pipeline step.
+        :param model: A model to which the pipeline step will be applied.
+        :param graph: A graph associated with the model.
+        :return: The updated model after executing the pipeline step.
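+
+        Note: the algorithms of the step are applied sequentially; the NNCF graph
+        is rebuilt after each algorithm except the last one, whose output model
+        is returned as is.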
+ """ + current_model = model + current_graph = graph + + pipeline_step = self.pipeline_steps[step_index] + for algorithm in pipeline_step[:-1]: + current_model = algorithm.apply(current_model, current_graph, step_statistics) + current_graph = NNCFGraphFactory.create(current_model) + current_model = pipeline_step[-1].apply(current_model, current_graph, step_statistics) + + return current_model + + def run_from_step( + self, + model: TModel, + dataset: Dataset, + graph: Optional[NNCFGraph] = None, + start_step_index: int = 0, + step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None, + ) -> TModel: + """ + Executes the pipeline from the specified pipeline step to the end. + + :param model: This is the model after the (start_step_index - 1)-th pipeline + step, or the initial model if start_step_index is 0. + :param dataset: A dataset that holds the data items for pipeline steps. + :param graph: A graph assosiated with a model. + :param start_step_index: Zero-based pipeline step index from which the pipeline + should be executed. + :param step_index_to_statistics: A mapping from pipeline step index to statistics + required to execute pipeline step. + :return: The updated model after executing the pipeline from the specified pipeline + step to the end. + """ + if step_index_to_statistics is None: + step_index_to_statistics = {} + + # The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step + step_model = model + step_graph = graph + for step_index in range(start_step_index, len(self.pipeline_steps)): + # Create graph required to run current pipeline step + if step_graph is None: + step_graph = NNCFGraphFactory.create(step_model) + + # Collect statistics required to run current pipeline step + step_statistics = step_index_to_statistics.get(step_index) + if step_statistics is None: + statistic_points = self.get_statistic_points_for_step(step_index, step_model, step_graph) + step_statistics = collect_statistics(statistic_points, step_model, step_graph, dataset) + + # Run current pipeline step + step_model = self.run_step(step_index, step_statistics, step_model, step_graph) + + step_graph = None # We should rebuild the graph for the next pipeline step + + return step_model + + def get_statistic_points_for_step( + self, step_index: int, model: TModel, graph: NNCFGraph + ) -> StatisticPointsContainer: + """ + Returns statistics that should be collected to execute `step_index`-th pipeline step. + + :param step_index: Zero-based index of the pipeline step. + :param model: A model. + :param graph: A graph assosiated with a model. + :return: Statistics that should be collected to execute `step_index`-th pipeline step. + """ + container = StatisticPointsContainer() + for algorithm in self.pipeline_steps[step_index]: + for statistic_points in algorithm.get_statistic_points(model, graph).values(): + for statistic_point in statistic_points: + container.add_statistic_point(statistic_point) + + return container diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index d6e6b40de80..5c9a0e7777e 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -9,30 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Dict, List, Optional, TypeVar +from typing import Callable, Dict, Optional, TypeVar from nncf import Dataset -from nncf.common.deprecation import warning_deprecated -from nncf.common.factory import NNCFGraphFactory -from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.graph.graph import NNCFGraph -from nncf.common.logging import nncf_logger from nncf.common.quantization.structs import QuantizationPreset from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType -from nncf.common.utils.backend import copy_model -from nncf.common.utils.backend import get_backend from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.algorithm import Algorithm -from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD -from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection -from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment -from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD -from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection -from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization -from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant +from nncf.quantization.algorithms.post_training.pipeline import create_ptq_pipeline from nncf.scopes import IgnoredScope TModel = TypeVar("TModel") @@ -78,95 +66,16 @@ def __init__( :param advanced_parameters: Advanced quantization parameters for fine-tuning the quantization algorithm """ - super().__init__() - self.algorithms = [] - self.first_stage_algorithms: List[Algorithm] = [] - - if target_device is TargetDevice.VPU: - warning_deprecated("VPU device is deprecated and will no longer be supported in the future.") - - if advanced_parameters is None: - advanced_parameters = AdvancedQuantizationParameters() - - if model_type == ModelType.TRANSFORMER and advanced_parameters.smooth_quant_alpha >= 0: - smooth_quant_algorithm = SmoothQuant( - subset_size=subset_size, - inplace_statistics=advanced_parameters.inplace_statistics, - alpha=advanced_parameters.smooth_quant_alpha, - ) - self.first_stage_algorithms.append(smooth_quant_algorithm) - - if not advanced_parameters.disable_channel_alignment: - channel_alignment = ChannelAlignment( - subset_size=subset_size, - inplace_statistics=advanced_parameters.inplace_statistics, - ) - self.first_stage_algorithms.append(channel_alignment) - - min_max_quantization = MinMaxQuantization( - preset=preset, - target_device=target_device, - subset_size=subset_size, - model_type=model_type, - ignored_scope=ignored_scope, - overflow_fix=advanced_parameters.overflow_fix, - quantize_outputs=advanced_parameters.quantize_outputs, - inplace_statistics=advanced_parameters.inplace_statistics, - activations_quantization_params=advanced_parameters.activations_quantization_params, - weights_quantization_params=advanced_parameters.weights_quantization_params, - activations_range_estimator_params=advanced_parameters.activations_range_estimator_params, - weights_range_estimator_params=advanced_parameters.weights_range_estimator_params, - backend_params=advanced_parameters.backend_params, + self._pipeline = create_ptq_pipeline( + preset, target_device, subset_size, 
fast_bias_correction, model_type, ignored_scope, advanced_parameters ) - self.algorithms.append(min_max_quantization) - - if advanced_parameters.disable_bias_correction: - return - - bias_correction_params = advanced_parameters.bias_correction_params - if fast_bias_correction: - threshold = FAST_BIAS_CORRECTION_THRESHOLD - if bias_correction_params.threshold is not None: - threshold = bias_correction_params.threshold - bias_correction = FastBiasCorrection( - subset_size=subset_size, - threshold=threshold, - apply_for_all_nodes=bias_correction_params.apply_for_all_nodes, - inplace_statistics=advanced_parameters.inplace_statistics, - backend_params=advanced_parameters.backend_params, - ) - else: - threshold = BIAS_CORRECTION_THRESHOLD - if bias_correction_params.threshold is not None: - threshold = bias_correction_params.threshold - bias_correction_subset_size = max(int(subset_size * 0.2), 1) - bias_correction = BiasCorrection( - subset_size=bias_correction_subset_size, - threshold=threshold, - apply_for_all_nodes=bias_correction_params.apply_for_all_nodes, - inplace_statistics=advanced_parameters.inplace_statistics, - backend_params=advanced_parameters.backend_params, - ) - - self.algorithms.append(bias_correction) - @property def available_backends(self) -> Dict[str, BackendType]: return def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - if self.first_stage_algorithms: - raise NotImplementedError( - "Statistic points are not supported yet for SmoothQuant and ChannelAlignment algorithms." - ) - - output = StatisticPointsContainer() - for algorithm in self.algorithms: - for statistic_points in algorithm.get_statistic_points(model, graph).values(): - for statistic_point in statistic_points: - output.add_statistic_point(statistic_point) - return output + return self._pipeline.get_statistic_points_for_step(0, model, graph) def apply( self, @@ -175,41 +84,14 @@ def apply( statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: - modified_model = copy_model(model) - modified_model_graph = graph - backend = get_backend(modified_model) - - for algorithm in self.first_stage_algorithms: - if isinstance(algorithm, SmoothQuant) and backend != BackendType.OPENVINO: - nncf_logger.debug(f"{backend.name} does not support SmoothQuant algorithm yet.") - continue - - if isinstance(algorithm, ChannelAlignment) and backend != BackendType.OPENVINO: - nncf_logger.debug(f"{backend.name} does not support ChannelAlignment algorithm yet.") - continue - - statistics_aggregator = StatisticsAggregatorFactory.create(modified_model, dataset) - algo_statistic_points = algorithm.get_statistic_points(modified_model, modified_model_graph) - statistics_aggregator.register_statistic_points(algo_statistic_points) - statistics_aggregator.collect_statistics(modified_model, modified_model_graph) - modified_model = algorithm.apply( - modified_model, modified_model_graph, statistics_aggregator.statistic_points + if dataset is None and len(self._pipeline.pipeline_steps) > 1: + raise ValueError( + "A dataset is required for the post-training quantization " + "algorithm to collect statistics for intermediate models." 
) - modified_model_graph = NNCFGraphFactory.create(modified_model) - - if statistic_points is None: - statistics_aggregator = StatisticsAggregatorFactory.create(modified_model, dataset) - for algorithm in self.algorithms: - algo_statistic_points = algorithm.get_statistic_points(modified_model, modified_model_graph) - statistics_aggregator.register_statistic_points(algo_statistic_points) - - statistics_aggregator.collect_statistics(modified_model, modified_model_graph) - statistic_points = statistics_aggregator.statistic_points - for algorithm in self.algorithms[:-1]: - modified_model = algorithm.apply(modified_model, modified_model_graph, statistic_points) - modified_model_graph = NNCFGraphFactory.create(modified_model) - # building the model graph is not required after the last algorithm - modified_model = self.algorithms[-1].apply(modified_model, modified_model_graph, statistic_points) + step_index_to_statistics = None + if statistic_points: + step_index_to_statistics = {0: statistic_points} - return modified_model + return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics) diff --git a/nncf/quantization/algorithms/post_training/pipeline.py b/nncf/quantization/algorithms/post_training/pipeline.py new file mode 100644 index 00000000000..7b522a39724 --- /dev/null +++ b/nncf/quantization/algorithms/post_training/pipeline.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, TypeVar + +from nncf.common.deprecation import warning_deprecated +from nncf.common.quantization.structs import QuantizationPreset +from nncf.parameters import ModelType +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD +from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection +from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment +from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD +from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization +from nncf.quantization.algorithms.pipeline import Pipeline +from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant +from nncf.scopes import IgnoredScope + +TModel = TypeVar("TModel") + + +def create_ptq_pipeline( + preset: QuantizationPreset = QuantizationPreset.PERFORMANCE, + target_device: TargetDevice = TargetDevice.ANY, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + ignored_scope: Optional[IgnoredScope] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> Pipeline: + """ + Creates a post-training quantization pipeline. 
+
+    The post-training quantization pipeline includes the following steps:
+        1) SmoothQuant
+        2) ChannelAlignment
+        3) MinMaxQuantization
+        4) FastBiasCorrection or BiasCorrection
+
+    :param preset: A preset that controls the quantization mode
+        (symmetric and asymmetric). It can take the following values:
+        - `performance`: Symmetric quantization of weights and activations.
+        - `mixed`: Symmetric quantization of weights and asymmetric
+          quantization of activations.
+    :param target_device: A target device, the specificity of which will be taken
+        into account while compressing in order to obtain the best performance
+        for this type of device.
+    :param subset_size: Size of a subset to calculate activation
+        statistics used for quantization.
+    :param fast_bias_correction: Setting this option to `False` enables a different
+        bias correction method which is more accurate, in general, and takes
+        more time but requires less memory.
+    :param model_type: Model type is needed to specify additional patterns
+        in the model. Currently, only `transformer` is supported.
+    :param ignored_scope: An ignored scope that defines the list of model control
+        flow graph nodes to be ignored during quantization.
+    :param advanced_parameters: Advanced quantization parameters for
+        fine-tuning the quantization algorithm.
+    :return: A post-training quantization pipeline.
+    """
+    if target_device is TargetDevice.VPU:
+        warning_deprecated("VPU device is deprecated and will no longer be supported in the future.")
+
+    if advanced_parameters is None:
+        advanced_parameters = AdvancedQuantizationParameters()
+
+    # Build the post-training quantization pipeline.
+    pipeline_steps = []
+
+    # Add the `SmoothQuant` algorithm as the first step of the pipeline.
+    # It is added only for `ModelType.TRANSFORMER`.
+    if model_type == ModelType.TRANSFORMER and advanced_parameters.smooth_quant_alpha >= 0:
+        pipeline_steps.append(
+            [SmoothQuant(subset_size, advanced_parameters.inplace_statistics, advanced_parameters.smooth_quant_alpha)]
+        )
+
+    # Add the `ChannelAlignment` algorithm as the second step of the pipeline.
+    if not advanced_parameters.disable_channel_alignment:
+        pipeline_steps.append([ChannelAlignment(subset_size, advanced_parameters.inplace_statistics)])
+
+    # Add the `MinMaxQuantization` algorithm as the third step of the pipeline.
+    pipeline_steps.append(
+        [
+            MinMaxQuantization(
+                preset,
+                target_device,
+                subset_size,
+                model_type,
+                ignored_scope,
+                advanced_parameters.overflow_fix,
+                advanced_parameters.quantize_outputs,
+                advanced_parameters.inplace_statistics,
+                advanced_parameters.activations_quantization_params,
+                advanced_parameters.weights_quantization_params,
+                advanced_parameters.activations_range_estimator_params,
+                advanced_parameters.weights_range_estimator_params,
+                advanced_parameters.backend_params,
+            )
+        ]
+    )
+
+    if not advanced_parameters.disable_bias_correction:
+        # Add `FastBiasCorrection` or `BiasCorrection` as an additional algorithm
+        # inside the third step of the pipeline. It is added after the
+        # `MinMaxQuantization` algorithm.
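+        # `BiasCorrection` takes more time than `FastBiasCorrection`, so it runs
+        # on a reduced calibration subset (20% of `subset_size`, but at least one
+        # sample); see `bias_correction_subset_size` below.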
+ bias_correction_params = advanced_parameters.bias_correction_params + if fast_bias_correction: + threshold = FAST_BIAS_CORRECTION_THRESHOLD + bias_correction_subset_size = subset_size + bias_correction_cls = FastBiasCorrection + else: + threshold = BIAS_CORRECTION_THRESHOLD + bias_correction_subset_size = max(int(subset_size * 0.2), 1) + bias_correction_cls = BiasCorrection + + if bias_correction_params.threshold is not None: + threshold = bias_correction_params.threshold + + pipeline_steps[-1].append( + bias_correction_cls( + bias_correction_subset_size, + threshold, + bias_correction_params.apply_for_all_nodes, + advanced_parameters.inplace_statistics, + advanced_parameters.backend_params, + ) + ) + + return Pipeline(pipeline_steps) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index d94efa1b704..2c64328f6cb 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -26,8 +26,8 @@ from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.accuracy_control.evaluator import MetricResults from nncf.quantization.algorithms.hyperparameter_tuner.algorithm import HyperparameterTuner -from nncf.quantization.algorithms.hyperparameter_tuner.param_grid import get_quantization_param_grid -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.quantization.algorithms.hyperparameter_tuner.param_grid import get_quantization_param_grids +from nncf.quantization.algorithms.post_training.pipeline import create_ptq_pipeline from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression from nncf.scopes import IgnoredScope @@ -333,12 +333,12 @@ def quantize_with_tune_hyperparams( "advanced_parameters": advanced_quantization_parameters, } - quantization_param_grid = get_quantization_param_grid() + param_grids = get_quantization_param_grids(create_ptq_pipeline(**init_quantization_params)) hyperparameter_tuner = HyperparameterTuner( - PostTrainingQuantization, + create_ptq_pipeline, init_quantization_params, - quantization_param_grid, + param_grids, calibration_dataset, validation_fn, tuner_subset_size, diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index a5b9c8f47e3..7ca77985533 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -111,7 +111,11 @@ def min_max_quantize_model( quantization_params = {} if quantization_params is None else quantization_params advanced_parameters = quantization_params.get("advanced_parameters", AdvancedQuantizationParameters()) + + # ONNX backend does not support these algorithms advanced_parameters.disable_bias_correction = True + advanced_parameters.disable_channel_alignment = True + advanced_parameters.smooth_quant_alpha = -1 quantization_params["advanced_parameters"] = advanced_parameters post_training_quantization = PostTrainingQuantization(subset_size=1, **quantization_params) diff --git a/tests/onnx/quantization/test_ptq_params.py b/tests/onnx/quantization/test_ptq_params.py index 9bc23b1410b..d0a88ad8cc8 100644 --- a/tests/onnx/quantization/test_ptq_params.py +++ b/tests/onnx/quantization/test_ptq_params.py @@ -24,8 +24,8 @@ from nncf.onnx.statistics.collectors import ONNXMeanMinMaxStatisticCollector from nncf.onnx.statistics.collectors import ONNXMinMaxStatisticCollector from nncf.parameters import TargetDevice +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from 
nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.scopes import IgnoredScope from tests.common.quantization.metatypes import Conv2dTestMetatype from tests.common.quantization.metatypes import LinearTestMetatype @@ -49,8 +49,7 @@ def get_ignored_patterns(device: TargetDevice = TargetDevice.ANY) -> GraphPatter @pytest.mark.parametrize("target_device", TargetDevice) def test_target_device(target_device): - algo = PostTrainingQuantization(target_device=target_device) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(target_device=target_device) min_max_algo._backend_entity = ONNXMinMaxAlgoBackend() assert min_max_algo._target_device == target_device diff --git a/tests/openvino/native/quantization/test_ptq_params.py b/tests/openvino/native/quantization/test_ptq_params.py index 3552915d523..e1d170b72b4 100644 --- a/tests/openvino/native/quantization/test_ptq_params.py +++ b/tests/openvino/native/quantization/test_ptq_params.py @@ -26,8 +26,8 @@ from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.transformations.commands import OVTargetPoint from nncf.parameters import TargetDevice +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.scopes import IgnoredScope from tests.common.quantization.metatypes import Conv2dTestMetatype from tests.common.quantization.metatypes import LinearTestMetatype @@ -49,8 +49,7 @@ def get_ignored_patterns(device: TargetDevice = TargetDevice.ANY) -> GraphPatter # pylint: disable=protected-access @pytest.mark.parametrize("target_device", [TargetDevice.CPU, TargetDevice.GPU, TargetDevice.VPU]) def test_target_device(target_device): - algo = PostTrainingQuantization(target_device=target_device) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(target_device=target_device) min_max_algo._backend_entity = OVMinMaxAlgoBackend() assert min_max_algo._target_device.value == HW_CONFIG_TYPE_TARGET_DEVICE_MAP[target_device.value] diff --git a/tests/post_training/test_templates/test_ptq_params.py b/tests/post_training/test_templates/test_ptq_params.py index a2ec340c2ff..8de8ce10451 100644 --- a/tests/post_training/test_templates/test_ptq_params.py +++ b/tests/post_training/test_templates/test_ptq_params.py @@ -27,10 +27,8 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.parameters import ModelType -from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.quantization.passes import transform_to_inference_graph from nncf.quantization.range_estimator import RangeEstimatorParametersSet from nncf.scopes import IgnoredScope @@ -131,12 +129,7 @@ def metatypes_mapping(self): "range_estimator_params", [RangeEstimatorParametersSet.MINMAX, RangeEstimatorParametersSet.MEAN_MINMAX, None] ) def test_range_estimator_per_tensor(self, test_params, range_estimator_params): - algo = 
PostTrainingQuantization( - advanced_parameters=AdvancedQuantizationParameters( - activations_range_estimator_params=range_estimator_params - ) - ) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(activations_range_estimator_params=range_estimator_params) min_max_algo._backend_entity = self.get_algo_backend() assert min_max_algo._range_estimator_params[QuantizerGroup.ACTIVATIONS] == range_estimator_params @@ -161,10 +154,7 @@ def test_range_estimator_per_tensor(self, test_params, range_estimator_params): @pytest.mark.parametrize("quantize_outputs", [False, True]) def test_quantize_outputs(self, test_params, quantize_outputs): - algo = PostTrainingQuantization( - advanced_parameters=AdvancedQuantizationParameters(quantize_outputs=quantize_outputs) - ) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(quantize_outputs=quantize_outputs) min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = test_params["test_quantize_outputs"]["nncf_graph"] @@ -189,8 +179,7 @@ def test_quantize_outputs(self, test_params, quantize_outputs): def test_ignored_scopes(self, test_params, ignored_scopes_data): ignored_scope, act_num_ref, weight_num_ref = ignored_scopes_data - algo = PostTrainingQuantization(ignored_scope=ignored_scope) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(ignored_scope=ignored_scope) min_max_algo._backend_entity = self.get_algo_backend() assert min_max_algo._ignored_scope == ignored_scope @@ -215,8 +204,7 @@ def test_ignored_scopes(self, test_params, ignored_scopes_data): @pytest.mark.parametrize("model_type", [ModelType.TRANSFORMER]) def test_model_type_pass(self, test_params, model_type): - algo = PostTrainingQuantization(preset=QuantizationPreset.MIXED, model_type=model_type) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(preset=QuantizationPreset.MIXED, model_type=model_type) min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = test_params["test_model_type_pass"]["nncf_graph"] diff --git a/tests/post_training/test_templates/test_quantizer_config.py b/tests/post_training/test_templates/test_quantizer_config.py index 72da2111a36..033b35377f5 100644 --- a/tests/post_training/test_templates/test_quantizer_config.py +++ b/tests/post_training/test_templates/test_quantizer_config.py @@ -30,9 +30,8 @@ from nncf.experimental.common.tensor_statistics.collectors import MaxReducer from nncf.experimental.common.tensor_statistics.collectors import MinReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import QuantizationParameters -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.passes import transform_to_inference_graph from nncf.quantization.range_estimator import RangeEstimatorParametersSet from tests.post_training.test_templates.models import NNCFGraphToTest @@ -86,8 +85,7 @@ def statistic_collector_parameters(self, request) -> TestGetStatisticsCollectorP pass def test_default_quantizer_config(self, single_conv_nncf_graph): - algo = PostTrainingQuantization() - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization() min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = single_conv_nncf_graph.nncf_graph inference_nncf_graph = 
transform_to_inference_graph( @@ -132,18 +130,15 @@ def test_quantizer_config_from_ptq_params_for_CPU( signed_activations, single_conv_nncf_graph, ): - algo = PostTrainingQuantization( + min_max_algo = MinMaxQuantization( preset=preset, - advanced_parameters=AdvancedQuantizationParameters( - activations_quantization_params=QuantizationParameters( - num_bits=activation_bits, per_channel=activation_per_channel, signedness_to_force=signed_activations - ), - weights_quantization_params=QuantizationParameters( - num_bits=weight_bits, per_channel=weight_per_channel, signedness_to_force=signed_weights - ), + activations_quantization_params=QuantizationParameters( + num_bits=activation_bits, per_channel=activation_per_channel, signedness_to_force=signed_activations + ), + weights_quantization_params=QuantizationParameters( + num_bits=weight_bits, per_channel=weight_per_channel, signedness_to_force=signed_weights ), ) - min_max_algo = algo.algorithms[0] min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = single_conv_nncf_graph.nncf_graph inference_nncf_graph = transform_to_inference_graph( @@ -184,8 +179,7 @@ def test_quantizer_config_from_ptq_params_for_CPU( assert quantization_point.qconfig.signedness_to_force == signed_activations def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph): - algo = PostTrainingQuantization() - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization() min_max_algo._backend_entity = self.get_algo_backend() nncf_graph = depthwise_conv_nncf_graph.nncf_graph inference_nncf_graph = transform_to_inference_graph( @@ -228,12 +222,7 @@ def test_get_stat_collector( statistic_collector_parameters: TestGetStatisticsCollectorParameters, ): params = statistic_collector_parameters - algo = PostTrainingQuantization( - advanced_parameters=AdvancedQuantizationParameters( - activations_range_estimator_params=range_estimator_params - ) - ) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(activations_range_estimator_params=range_estimator_params) min_max_algo._backend_entity = self.get_algo_backend() q_config = QuantizerConfig(num_bits=8, mode=q_config_mode, per_channel=q_config_per_channel) diff --git a/tests/torch/ptq/test_fq_params_calculation.py b/tests/torch/ptq/test_fq_params_calculation.py index 3b8b770b471..1dcb3d5d065 100644 --- a/tests/torch/ptq/test_fq_params_calculation.py +++ b/tests/torch/ptq/test_fq_params_calculation.py @@ -18,7 +18,6 @@ import nncf from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix -from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.torch.model_creation import create_nncf_network from nncf.torch.nncf_network import NNCFNetwork @@ -49,13 +48,14 @@ def transform_fn(sample): dataset = nncf.Dataset(dataloader, transform_func=transform_fn) - post_training_quantization = PostTrainingQuantization(subset_size=1, **quantization_params) # Using PTQ, but apply only MinMax - updated_algorithms = [] - for algo in post_training_quantization.algorithms: - if isinstance(algo, MinMaxQuantization): - updated_algorithms.append(algo) - post_training_quantization.algorithms = updated_algorithms + advanced_parameters = quantization_params.get("advanced_parameters", AdvancedQuantizationParameters()) + advanced_parameters.disable_bias_correction = True + 
advanced_parameters.disable_channel_alignment = True + advanced_parameters.smooth_quant_alpha = -1 + quantization_params["advanced_parameters"] = advanced_parameters + + post_training_quantization = PostTrainingQuantization(subset_size=1, **quantization_params) original_model.eval() nncf_network = create_nncf_network(original_model, config) diff --git a/tests/torch/ptq/test_ptq_params.py b/tests/torch/ptq/test_ptq_params.py index 35ddfe3128e..cdc09a7b81a 100644 --- a/tests/torch/ptq/test_ptq_params.py +++ b/tests/torch/ptq/test_ptq_params.py @@ -28,8 +28,8 @@ from nncf.quantization.advanced_parameters import OverflowFix from nncf.quantization.advanced_parameters import QuantizationMode from nncf.quantization.advanced_parameters import QuantizationParameters +from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend -from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.quantization.range_estimator import RangeEstimatorParametersSet from nncf.scopes import IgnoredScope from nncf.torch.graph.graph import PTTargetPoint @@ -96,8 +96,7 @@ def forward(self, x): @pytest.mark.parametrize("target_device", TargetDevice) def test_target_device(target_device): - algo = PostTrainingQuantization(target_device=target_device) - min_max_algo = algo.algorithms[0] + min_max_algo = MinMaxQuantization(target_device=target_device) min_max_algo._backend_entity = PTMinMaxAlgoBackend() assert min_max_algo._target_device == target_device
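
### Usage example

A minimal sketch of how the new pieces compose. `model`, `calibration_loader`, and `transform_fn` are placeholders that the caller must provide; everything else follows the APIs added in this PR:

```python
import nncf
from nncf.quantization.algorithms.hyperparameter_tuner.param_grid import get_quantization_param_grids
from nncf.quantization.algorithms.post_training.pipeline import create_ptq_pipeline

# `model`, `calibration_loader`, and `transform_fn` are user-provided placeholders.
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)

# Plain PTQ: build the pipeline once and execute all of its steps.
pipeline = create_ptq_pipeline(subset_size=300, fast_bias_correction=True)
quantized_model = pipeline.run(model, calibration_dataset)

# Hyperparameter tuning: `HyperparameterTuner` now receives the pipeline factory
# and one parameter grid per pipeline step instead of a single algorithm class,
# as wired up in `quantize_with_tune_hyperparams()`.
init_params = {"subset_size": 300, "fast_bias_correction": True}
param_grids = get_quantization_param_grids(create_ptq_pipeline(**init_params))
```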