Add SmoothQuant and ChannelAlignment to HyperparameterTuner (#2154)
### Changes

- Add the `Pipeline` class. To use an algorithm with the
`HyperparameterTuner`, it must now be wrapped in a `Pipeline` object (see the sketch below).
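A minimal sketch of the new usage, assuming the `Pipeline` constructor takes a list of pipeline steps, where each step is a list of algorithm instances (the constructor is not shown in this diff, and the algorithm arguments are illustrative):

```python
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.algorithms.pipeline import Pipeline
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant


def create_pipeline(**params) -> Pipeline:
    # `HyperparameterTuner` calls this factory (its `pipeline_fn` argument)
    # once per parameter combination, so every combination yields a freshly
    # configured pipeline. For brevity, all parameters are forwarded to
    # MinMaxQuantization here.
    return Pipeline(
        [
            [SmoothQuant()],  # step 0
            [MinMaxQuantization(**params)],  # step 1
        ]
    )
```

The factory is then passed to the `HyperparameterTuner` constructor as `pipeline_fn`, together with `init_params` and the per-step `param_grids` produced by `get_quantization_param_grids` (see `param_grid.py` below).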

### Reason for changes

- Add the SmoothQuant and ChannelAlignment algorithms to the parameter
optimization process.

### Related tickets

Ref: 117471

### Tests

<!--- How was the correctness of changes tested and whether new tests
were added -->
andrey-churkin authored Oct 17, 2023
1 parent d982620 commit 5150867
Showing 13 changed files with 520 additions and 251 deletions.
161 changes: 106 additions & 55 deletions nncf/quantization/algorithms/hyperparameter_tuner/algorithm.py
@@ -14,20 +14,21 @@
import functools
import itertools
import operator
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union

from nncf.common.factory import NNCFGraphFactory
-from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.logging import nncf_logger
+from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import get_backend
from nncf.common.utils.timer import timer
from nncf.data.dataset import Dataset
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
from nncf.quantization.algorithms.accuracy_control.evaluator import MetricResults
from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func
from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset
-from nncf.quantization.algorithms.algorithm import Algorithm
+from nncf.quantization.algorithms.pipeline import Pipeline
+from nncf.quantization.algorithms.pipeline import collect_statistics

TModel = TypeVar("TModel")
TTensor = TypeVar("TTensor")
@@ -111,7 +112,9 @@ def apply_combination(init_params: Dict[str, Any], combination: Combination) ->
return params


-def print_combination_and_score(title: str, combination: Combination, combination_score: float) -> None:
+def print_combination_and_score(
+    title: str, combination: Combination, combination_score: Optional[float] = None
+) -> None:
"""
Prints combination and score.
@@ -126,7 +129,9 @@ def print_combination_and_score(title: str, combination: Combination, combinatio
message = f"{title} {message}"

nncf_logger.info(message)
nncf_logger.info(f"Score: {combination_score}")

if combination_score:
nncf_logger.info(f"Score: {combination_score}")


def find_best_combination(
@@ -186,7 +191,7 @@ class HyperparameterTuner:
"param_name": [0.1, 0.2],
}
-    The parameter names should be the same as in the `algorithm_cls.__init__()` method.
+    The parameter names should be the same as in the `pipeline_fn()` method.
    If the "param_name" parameter is a dataclass object, settings to try for its fields
    can be specified using the ":" marker, e.g. "advanced_parameters:smooth_quant_alpha".
@@ -214,17 +219,17 @@ class HyperparameterTuner:

def __init__(
self,
-        algorithm_cls: Type[Algorithm],
+        pipeline_fn: Callable[..., Pipeline],
init_params: Dict[str, Any],
-        param_grid: Dict[str, List[Any]],
+        param_grids: List[Dict[str, List[Any]]],
calibration_dataset: Dataset,
validation_fn: Callable[[Any, Iterable[Any]], Tuple[float, Union[None, List[float], List[List[TTensor]]]]],
subset_size: int,
initial_metric_results: MetricResults,
quantized_metric_results: MetricResults,
):
"""
-        :param algorithm_cls: Class of algorithm.
+        :param pipeline_fn: Function to create pipeline.
:param init_params: Initial set of parameters used to create algorithm.
        :param param_grids: List of dictionaries (one per pipeline step) with parameter
            names as keys and lists of parameter settings to try as values.
@@ -235,9 +240,9 @@ def __init__(
:param initial_metric_results: Metric results for initial model.
:param quantized_metric_results: Metric results for quantized with `init_params` model.
"""
-        self._algorithm_cls = algorithm_cls
+        self._pipeline_fn = pipeline_fn
self._init_params = init_params
-        self._param_grid = param_grid
+        self._param_grids = param_grids
self._calibration_dataset = calibration_dataset
self._evaluator = Evaluator(validation_fn)
self._subset_size = subset_size
@@ -246,12 +251,12 @@ def __init__(

self._is_metric_mode = isinstance(self._initial_metric_results.values_for_each_item[0], float)

-        # # Will be initialized inside `apply()` method
+        # Will be initialized inside `apply()` method
self._error_fn = None

-        # Will be initialized inside `_initialize_algorithms()` method
-        self._algorithms: Dict[CombinationKey, Algorithm] = {}
-        self._statistic_points = None
+        # Will be initialized inside `_prepare_pipeline_step()` method
+        self._pipelines: Dict[CombinationKey, Pipeline] = {}
+        self._step_index_to_statistics: Dict[int, StatisticPointsContainer] = {}

self._calculated_scores: Dict[CombinationKey, float] = {}

@@ -275,58 +280,101 @@ def apply(self, model: TModel, validation_dataset: Dataset) -> TModel:
self._error_fn,
)

-        combinations = create_combinations(self._param_grid)
+        step_model = model  # The model on which the `step_index`-th pipeline step will be executed
+        best_settings = {}
+
+        for step_index, step_param_grid in enumerate(self._param_grids):
+            step_graph = NNCFGraphFactory.create(step_model)
+
+            # If there are no parameters to optimize for the current step, simply execute
+            # this pipeline step on the model.
+            if not step_param_grid:
+                # TODO(andrey-churkin): Think about how it can be avoided.
+                params = apply_combination(self._init_params, best_settings)
+                pipeline = self._pipeline_fn(**params)
+                container = pipeline.get_statistic_points_for_step(step_index, step_model, step_graph)
+                step_statistics = collect_statistics(container, step_model, step_graph, self._calibration_dataset)
+                step_model = pipeline.run_step(step_index, step_statistics, step_model, step_graph)
+                continue

-        initial_graph = NNCFGraphFactory.create(model)
+            step_combinations = create_combinations(step_param_grid)

nncf_logger.info("Start initialization of algorithms")
with timer():
self._prepare_algorithms(model, initial_graph, combinations)
nncf_logger.info(f"Start preparation for {step_index}-th pipeline step")
with timer():
self._prepare_pipeline_step(step_index, step_model, step_graph, step_combinations, best_settings)

-        combination_score_fn = functools.partial(
-            self._calculate_combination_score,
-            initial_model=model,
-            initial_graph=initial_graph,
-            dataset=validation_dataset,
-            subset_indices=subset_indices,
-        )
+            combination_score_fn = functools.partial(
+                self._calculate_combination_score,
+                step_index=step_index,
+                step_model=step_model,
+                step_graph=step_graph,
+                dataset=validation_dataset,
+                subset_indices=subset_indices,
+            )

nncf_logger.info("Start search best combination of parameters")
with timer():
step_best_combination_key = find_best_combination(
step_combinations, combination_score_fn, step_param_grid
)

nncf_logger.info("Start search best combination of parameters")
with timer():
best_combination_key = find_best_combination(combinations, combination_score_fn, self._param_grid)
best_settings.update(step_combinations[step_best_combination_key])
pipeline = self._pipelines[step_best_combination_key]
step_model = pipeline.run_step(
step_index, self._step_index_to_statistics[step_index], step_model, step_graph
)

-        algorithm = self._algorithms[best_combination_key]
-        result_model = algorithm.apply(model, initial_graph, self._statistic_points)
+        print_combination_and_score("Final best combination of parameters:", best_settings)

-        return result_model
+        return step_model

-    def _prepare_algorithms(
-        self, initial_model: TModel, initial_graph: NNCFGraph, combinations: Dict[CombinationKey, Combination]
+    def _prepare_pipeline_step(
+        self,
+        step_index: int,
+        step_model: TModel,
+        step_graph: NNCFGraph,
+        step_combinations: Dict[CombinationKey, Combination],
+        best_settings: Combination,
) -> None:
"""
-        Creates algorithm for each combination of parameters. Collects statistics for
-        created algorithms.
-
-        :param initial_model: Input model used to collect statistics for algorithms.
-        :param combinations: Combinations of parameters.
+        Creates a separate pipeline for each combination from `step_combinations`.
+        Each combination only changes the parameters of the `step_index`-th pipeline
+        step. After that, merges the statistic points required to execute the
+        `step_index`-th pipeline step and collects them using `step_model`,
+        `step_graph`, and the calibration dataset.
+
+        :param step_index: Zero-based index of the pipeline step that should be prepared.
+        :param step_model: A model.
+        :param step_graph: A graph associated with the model.
+        :param step_combinations: Combinations that change parameters only for the
+            `step_index`-th pipeline step.
+        :param best_settings: Best parameter settings found for the previous pipeline steps.
-        for combination_key, combination in combinations.items():
-            kwargs = apply_combination(self._init_params, combination)
-            self._algorithms[combination_key] = self._algorithm_cls(**kwargs)
-
-        # Collect required statistics for created algorithms
-        stats_aggregator = StatisticsAggregatorFactory.create(initial_model, self._calibration_dataset)
-        for algorithm in self._algorithms.values():
-            statistic_points = algorithm.get_statistic_points(initial_model, initial_graph)
-            stats_aggregator.register_statistic_points(statistic_points)
-        stats_aggregator.collect_statistics(initial_model, initial_graph)
-        self._statistic_points = stats_aggregator.statistic_points
+        # Create a separate pipeline for each combination
+        # TODO(andrey-churkin): Think about how it can be avoided. In an ideal scenario,
+        # we would have only one pipeline and set parameters directly within it.
+        self._pipelines = {}
+        for combination_key, combination in step_combinations.items():
+            settings = {}
+            settings.update(combination)
+            settings.update(best_settings)
+            kwargs = apply_combination(self._init_params, settings)
+            self._pipelines[combination_key] = self._pipeline_fn(**kwargs)
+
+        # Collect statistics required to execute `step_index`-th pipeline step
+        containers = [
+            pipeline.get_statistic_points_for_step(step_index, step_model, step_graph)
+            for pipeline in self._pipelines.values()
+        ]
+        self._step_index_to_statistics[step_index] = collect_statistics(
+            containers, step_model, step_graph, self._calibration_dataset
+        )

def _calculate_combination_score(
self,
combination_key: CombinationKey,
-        initial_model: TModel,
-        initial_graph: NNCFGraph,
+        step_index: int,
+        step_model: TModel,
+        step_graph: NNCFGraph,
dataset: Dataset,
subset_indices: List[int],
) -> float:
@@ -343,8 +391,11 @@ def _calculate_combination_score(
if combination_key in self._calculated_scores:
return self._calculated_scores[combination_key]

-        algorithm = self._algorithms[combination_key]
-        model = algorithm.apply(initial_model, initial_graph, self._statistic_points)
+        pipeline = self._pipelines[combination_key]
+        model = pipeline.run_from_step(
+            step_model, self._calibration_dataset, step_graph, step_index, self._step_index_to_statistics
+        )

score = self._validate_model(model, dataset, subset_indices)
self._calculated_scores[combination_key] = score

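For reference, a sketch of the ":" marker semantics that the param grids rely on. The `apply_setting` helper and `AdvancedParams` dataclass below are hypothetical; they only mirror what `apply_combination` does for dataclass-valued parameters:

```python
from dataclasses import dataclass, replace
from typing import Any, Dict


@dataclass
class AdvancedParams:
    # Hypothetical stand-in for NNCF's advanced-parameters dataclasses.
    smooth_quant_alpha: float = 0.95


def apply_setting(params: Dict[str, Any], name: str, value: Any) -> Dict[str, Any]:
    # A name such as "advanced_parameters:smooth_quant_alpha" addresses the
    # `smooth_quant_alpha` field of the dataclass stored under
    # "advanced_parameters"; a plain name replaces the value directly.
    if ":" in name:
        top, field = name.split(":", maxsplit=1)
        return {**params, top: replace(params[top], **{field: value})}
    return {**params, name: value}


init_params = {"preset": "performance", "advanced_parameters": AdvancedParams()}
updated = apply_setting(init_params, "advanced_parameters:smooth_quant_alpha", 0.25)
print(updated["advanced_parameters"].smooth_quant_alpha)  # 0.25
```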
51 changes: 44 additions & 7 deletions nncf/quantization/algorithms/hyperparameter_tuner/param_grid.py
@@ -10,19 +10,24 @@
# limitations under the License.

import itertools
-from typing import Any, Dict
+from typing import Any, Dict, List

from nncf.common.quantization.structs import QuantizationPreset
+from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
+from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment
+from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
+from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
+from nncf.quantization.algorithms.pipeline import Pipeline
+from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.range_estimator import AggregatorType
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import StatisticsCollectorParameters
from nncf.quantization.range_estimator import StatisticsType

+ParamGrid = Dict[str, List[Any]]

-def get_quantization_param_grid() -> Dict[str, Any]:
-    """
-    Returns params grid for post-training quantization algorithm.
-    """
+def _get_minmax_quantization_param_grid() -> ParamGrid:
min_param_values = [
StatisticsCollectorParameters(
statistics_type=StatisticsType.MIN,
Expand Down Expand Up @@ -58,7 +63,6 @@ def get_quantization_param_grid() -> Dict[str, Any]:

param_grid = {
"preset": [QuantizationPreset.PERFORMANCE, QuantizationPreset.MIXED],
"fast_bias_correction": [True, False],
"advanced_parameters:weights_range_estimator_params": [
RangeEstimatorParameters(
min=StatisticsCollectorParameters(statistics_type=StatisticsType.MIN),
@@ -70,5 +74,38 @@ def get_quantization_param_grid() -> Dict[str, Any]:
for min_v, max_v in itertools.product(min_param_values, max_param_values)
],
}

return param_grid


+def _get_smooth_quant_param_grid() -> ParamGrid:
+    return {"advanced_parameters:smooth_quant_alpha": [0.15, 0.25, 0.5, 0.75, 0.95]}
+
+
+def _get_channel_alignment_param_grid() -> ParamGrid:
+    return {}
+
+
+def _get_bias_correction_param_grid() -> ParamGrid:
+    return {"fast_bias_correction": [True, False]}
+
+
+def get_quantization_param_grids(pipeline: Pipeline) -> List[ParamGrid]:
+    """
+    Returns a parameter grid for each step of the given post-training quantization pipeline.
+    """
+    algorithm_cls_to_param_grid = {
+        SmoothQuant: _get_smooth_quant_param_grid(),
+        ChannelAlignment: _get_channel_alignment_param_grid(),
+        MinMaxQuantization: _get_minmax_quantization_param_grid(),
+        FastBiasCorrection: _get_bias_correction_param_grid(),
+        BiasCorrection: _get_bias_correction_param_grid(),
+    }
+
+    param_grids = []
+    for step in pipeline.pipeline_steps:
+        param_grid = {}
+        for algorithm in step:
+            param_grid.update(algorithm_cls_to_param_grid[algorithm.__class__])
+        param_grids.append(param_grid)
+
+    return param_grids
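To illustrate the result, assume a hypothetical three-step pipeline `[[SmoothQuant()], [ChannelAlignment()], [MinMaxQuantization(), FastBiasCorrection()]]`. `get_quantization_param_grids` would then return one grid per step, merging the grids of algorithms that share a step:

```python
from nncf.common.quantization.structs import QuantizationPreset

param_grids = [
    # step 0: SmoothQuant
    {"advanced_parameters:smooth_quant_alpha": [0.15, 0.25, 0.5, 0.75, 0.95]},
    # step 1: ChannelAlignment exposes no tunable parameters, so `apply()`
    # executes this step once instead of searching over combinations
    {},
    # step 2: MinMaxQuantization and FastBiasCorrection grids merged; the
    # range-estimator entries are abridged (see the grid built above)
    {
        "preset": [QuantizationPreset.PERFORMANCE, QuantizationPreset.MIXED],
        "advanced_parameters:weights_range_estimator_params": [...],
        "fast_bias_correction": [True, False],
    },
]
```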