Commit
Refine code and update autotune API (#1655)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Mar 8, 2024
1 parent 28fb965 commit 3a254e9
Showing 13 changed files with 219 additions and 136 deletions.
30 changes: 18 additions & 12 deletions neural_compressor/common/base_config.py
@@ -25,7 +25,8 @@
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from neural_compressor.common import Logger
from typing_extensions import Self

from neural_compressor.common.tuning_param import TuningParam
from neural_compressor.common.utils import (
BASE_CONFIG,
@@ -36,10 +37,9 @@
GLOBAL,
LOCAL,
OP_NAME_OR_MODULE_TYPE,
logger,
)

logger = Logger().get_logger()

__all__ = [
"options",
"register_config",
@@ -52,11 +52,18 @@


# Config registry to store all registered configs.
class ConfigRegistry:
class ConfigRegistry(object):
registered_configs = {}
_config_registry = None

def __new__(cls) -> Self:
if cls._config_registry is None:
cls._config_registry = super(ConfigRegistry, cls).__new__(cls)

return cls._config_registry

@classmethod
def register_config_impl(cls, framework_name="None", algo_name=None, priority=0):
def register_config_impl(cls, framework_name: str, algo_name: str, priority: Union[float, int] = 0):
"""Register config decorator.
Registers the configuration classes for different algorithms within specific frameworks.
@@ -67,8 +74,8 @@ class ExampleAlgorithmConfig:
# Configuration details for the ExampleAlgorithm
Args:
framework_name: the framework name. Defaults to "None".
algo_name: the algorithm name. Defaults to None.
framework_name: the framework name.
algo_name: the algorithm name.
priority: the priority of the configuration. A larger number indicates a higher priority,
which will be tried first at the auto-tune stage. Defaults to 0.
"""
@@ -116,7 +123,7 @@ def get_all_config_cls_by_fwk_name(cls, fwk_name: str) -> List[Type[BaseConfig]]
config_registry = ConfigRegistry()


def register_config(framework_name="None", algo_name=None, priority=0):
def register_config(framework_name: str, algo_name: str, priority: Union[float, int] = 0):
"""Register config decorator.
Registers the configuration classes for different algorithms within specific frameworks.
@@ -127,8 +134,8 @@ class ExampleAlgorithmConfig:
# Configuration details for the ExampleAlgorithm
Args:
framework_name: the framework name. Defaults to "None".
algo_name: the algorithm name. Defaults to None.
framework_name: the framework name.
algo_name: the algorithm name.
priority: the priority of the configuration. A larger number indicates a higher priority,
which will be tried first at the auto-tune stage. Defaults to 0.
"""
@@ -411,7 +418,7 @@ def to_config_mapping(

@staticmethod
def _is_op_type(name: str) -> bool:
# TODO (Yi), ort and tf need override it
# * Ort and TF may override this method.
return not isinstance(name, str)

@classmethod
@@ -461,7 +468,6 @@ def to_config_mapping(
) -> OrderedDict[str, BaseConfig]:
config_mapping = OrderedDict()
for config in self.config_list:
global_config = config.global_config
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
single_config_model_info = model_info.get(config.name, None)
for op_name, op_type in single_config_model_info:
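The ConfigRegistry change above turns the registry into a singleton by overriding __new__. A minimal, self-contained sketch of that pattern, using only standard Python (the class below is illustrative, not part of the library):

class RegistrySketch:
    """Illustrative singleton mirroring the __new__ pattern used by ConfigRegistry."""

    _instance = None
    registered_configs: dict = {}

    def __new__(cls) -> "RegistrySketch":
        # Every instantiation returns the same shared object, so all call sites
        # see one registered_configs mapping.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance


assert RegistrySketch() is RegistrySketch()  # both calls yield the shared instance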
62 changes: 36 additions & 26 deletions neural_compressor/common/base_tuning.py
@@ -17,14 +17,12 @@
import uuid
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union

from neural_compressor.common import Logger
from neural_compressor.common.base_config import BaseConfig
from neural_compressor.common.utils import TuningLogger

logger = Logger().get_logger()
from neural_compressor.common.utils import TuningLogger, logger

__all__ = [
"Evaluator",
"EvaluationFuncWrapper",
"TuningConfig",
"Sampler",
"ConfigLoader",
@@ -37,9 +35,27 @@
]


class EvaluationFuncWrapper:
def __init__(self, eval_fn: Callable, eval_args=None):
"""Evaluation function wrapper.
Args:
eval_fn: a function for evaluating the float or quantized model
eval_args: positional arguments for `eval_fn`
"""
self.eval_fn = eval_fn
self.eval_args = eval_args

def evaluate(self, model) -> Union[float, int]:
result = self.eval_fn(model, *self.eval_args) if self.eval_args else self.eval_fn(model)
return result


class Evaluator:
"""Evaluator is a collection of evaluation functions.
Note: this class will be deprecated in the future.
Examples:
def eval_acc(model):
...
@@ -82,7 +98,6 @@ def evaluate(self, model) -> float:
return result

def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> float:
# TODO update the result according to the weight and algo_name
return overall_result + eval_result * eval_pair[self.WEIGHT]

def get_number_of_eval_functions(self) -> int:
@@ -229,29 +244,25 @@ class TuningConfig:
"""Config for auto tuning pipeline.
Examples:
# TODO: to refine it
from neural_compressor.torch.quantization import TuningConfig
tune_config = TuningConfig(
config_set=[config1, config2, ...],
max_trials=3,
tolerable_loss=0.01
)
# Case 1: Tolerable Loss
fp32_baseline = 100
config1_metric, config2_metric, ... = 98, 99, ...
# Tuning result of case 1:
# The best tuning config is config2, because config2_metric >= fp32_baseline * (1 - tolerable_loss)
# Case 2: Maximum Trials
fp32_baseline = 100
config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
# Tuning result of case 2:
# The best tuning config is config2, because of the following:
# 1. Not achieving the set goal. (config_metric < fp32_baseline * (1 - tolerable_loss))
# 2. Reached maximum tuning times.
tolerable_loss=0.01)
The tuning process stops when either of the following conditions is met:
1) The number of trials reaches the maximum trials.
2) The metric loss is within the tolerable loss.
For condition 2), we calculate the metric loss as follows:
relative_loss = (fp32_baseline - eval_result_of_q_model) / fp32_baseline
If relative_loss <= tolerable_loss, we stop the tuning process.
For example:
tolerable_loss = 0.01
fp32_baseline = 100
eval_result_of_q_model = 99
relative_loss = (100 - 99) / 100 = 0.01
The metric loss is within the tolerable loss, so the tuning process is stopped.
"""

def __init__(
@@ -321,10 +332,9 @@ def need_stop(self) -> bool:
"""Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.
Returns:
bool: True if need to stop, otherwise False.
stop_flag: True if need to stop, otherwise False.
"""

# TODO: Support more stop criteria in the next PR, such as `timeout`, and so on.
# reach max trials
reach_max_trials = self.trial_cnt >= self.tuning_config.max_trials
# reach accuracy goal
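The base_tuning.py changes above introduce EvaluationFuncWrapper and spell out the tolerable-loss stopping rule in the TuningConfig docstring. A hedged sketch of both, with made-up numbers mirroring the docstring's example (toy_accuracy_fn and "dummy_model" are placeholders):

from neural_compressor.common.base_tuning import EvaluationFuncWrapper

# Hypothetical evaluation function; a real one would run a dataloader and return a metric.
def toy_accuracy_fn(model, scale=1.0) -> float:
    return 99.0 * scale

wrapper = EvaluationFuncWrapper(eval_fn=toy_accuracy_fn)            # evaluate() calls toy_accuracy_fn(model)
wrapper_with_args = EvaluationFuncWrapper(toy_accuracy_fn, (1.0,))  # evaluate() calls toy_accuracy_fn(model, 1.0)
eval_result = wrapper_with_args.evaluate("dummy_model")             # -> 99.0

# Stopping rule from the TuningConfig docstring: stop once the relative metric loss
# of the quantized model is within tolerable_loss, or max_trials is reached.
fp32_baseline, tolerable_loss, max_trials, trial_cnt = 100.0, 0.01, 3, 1
relative_loss = (fp32_baseline - eval_result) / fp32_baseline               # 0.01
need_stop = (relative_loss <= tolerable_loss) or (trial_cnt >= max_trials)  # True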
2 changes: 1 addition & 1 deletion neural_compressor/common/tuning_param.py
@@ -14,7 +14,7 @@

import typing
from enum import Enum, auto
from typing import Any, List
from typing import Any

from pydantic import BaseModel

32 changes: 14 additions & 18 deletions neural_compressor/onnxrt/quantization/autotune.py
@@ -15,13 +15,13 @@
import os
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import onnx

from neural_compressor.common import logger
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
from neural_compressor.common.utils import logger
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME
from neural_compressor.onnxrt.quantization.quantize import _quantize
@@ -39,7 +39,8 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
def autotune(
model_input: Union[Path, str],
tune_config: TuningConfig,
eval_fns: Union[Dict, List[Dict], Callable] = None,
eval_fn: Callable,
eval_args: Optional[Tuple[Any]] = None,
calibration_data_reader: CalibrationDataReader = None,
) -> Union[None, onnx.ModelProto]:
"""The main entry of auto-tune.
@@ -51,27 +52,22 @@ def autotune(
Support:
Expand parameters to a list of parameters like TuningConfig(config_set=[RTNConfig(weight_bits=[4, 8])])
Pass a list of configs like TuningConfig(config_set=[RTNConfig(), GPTQConfig()])
eval_fns (Union[Dict, List[Dict], Callable]): evaluate functions.
During evaluation, autotune will only pass model path as input into eatch function.
Support:
single eval function,
Dict like {"eval_fn": eval_acc} or {"eval_fn": eval_acc, "weight": 1.0, "name": "accuracy"},
List of Dict, like [
{"eval_fn": eval_acc, "weight": 0.5},
{"eval_fn": eval_perf, "weight": 0.5, "name": "accuracy"},
]
eval_fn (Callable): evaluate function.
During evaluation, autotune will only pass the model path as the input of the function.
eval_args (Optional[Tuple[Any]]): evaluate arguments.
Positional arguments for `eval_fn`.
calibration_data_reader (CalibrationDataReader): dataloader for calibration.
"""
best_quant_model = None
evaluator.set_eval_fn_registry(eval_fns)
evaluator.self_check()
eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
try:
baseline: float = evaluator.evaluate(model_input)
baseline: float = eval_func_wrapper.evaluate(model_input)
except Exception as e:
print(e)
if "'str' object has no attribute 'SerializeToString'" in str(e):
logger.warning("Please refine your eval_fns to accept model path (str) as input.")
logger.warning("Please refine your eval_fn to accept model path (str) as input.")
exit(0)
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
@@ -105,7 +101,7 @@ def autotune(
Path(model_input).parent.joinpath("config.json").as_posix(),
Path(tmp_dir).joinpath("config.json").as_posix(),
)
eval_result: float = evaluator.evaluate(Path(tmp_dir).joinpath(Path(model_input).name).as_posix())
eval_result: float = eval_func_wrapper.evaluate(Path(tmp_dir).joinpath(Path(model_input).name).as_posix())
tuning_logger.evaluation_end()
logger.info("Evaluation result: %.4f", eval_result)
tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
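Under the updated signature shown above, the ONNX Runtime autotune takes a single eval_fn plus optional positional eval_args instead of the old eval_fns dict/list. A hedged call sketch; the model path, metric, and calibration data reader are placeholders:

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.onnxrt.quantization.autotune import autotune, get_all_config_set

def eval_acc(model_path: str) -> float:
    # Placeholder: load the ONNX model at `model_path` and return its accuracy.
    return 0.9

data_reader = None  # replace with a CalibrationDataReader over your calibration dataset

best_model = autotune(
    model_input="model.onnx",  # placeholder path to an ONNX model
    tune_config=TuningConfig(config_set=get_all_config_set(), max_trials=3, tolerable_loss=0.01),
    eval_fn=eval_acc,          # must accept a model path (str), as the warning above notes
    eval_args=None,
    calibration_data_reader=data_reader,
)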
14 changes: 7 additions & 7 deletions neural_compressor/tensorflow/quantization/autotune.py
@@ -13,13 +13,13 @@
# limitations under the License.

from copy import deepcopy
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import tensorflow as tf

from neural_compressor.common import logger
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.common.utils import dump_elapsed_time
from neural_compressor.tensorflow.quantization import quantize_model
from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
@@ -39,16 +39,16 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
def autotune(
model: Union[str, tf.keras.Model, BaseModel],
tune_config: TuningConfig,
eval_fns: Optional[Union[Dict, List[Dict]]] = None,
eval_fn: Callable,
eval_args: Optional[Tuple[Any]] = None,
calib_dataloader: Callable = None,
calib_iteration: int = 100,
) -> Optional[BaseModel]:
"""The main entry of auto-tune."""
best_quant_model = None
evaluator.set_eval_fn_registry(eval_fns)
evaluator.self_check()
eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
baseline: float = evaluator.evaluate(model)
baseline: float = eval_func_wrapper.evaluate(model)
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
for trial_index, quant_config in enumerate(config_loader):
@@ -58,7 +58,7 @@
q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
tuning_logger.quantization_end()
tuning_logger.evaluation_start()
eval_result: float = evaluator.evaluate(q_model)
eval_result: float = eval_func_wrapper.evaluate(q_model)
tuning_logger.evaluation_end()
tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
tuning_logger.trial_end(trial_index)
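The TensorFlow entry point follows the same eval_fn/eval_args pattern; a compact, hedged sketch with a placeholder Keras model, metric, and calibration dataloader:

import tensorflow as tf

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.tensorflow.quantization.autotune import autotune, get_all_config_set

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # placeholder model
calib_dataloader = None  # replace with a calibration dataloader for your dataset

best_model = autotune(
    model=model,
    tune_config=TuningConfig(config_set=get_all_config_set(), max_trials=3, tolerable_loss=0.01),
    eval_fn=lambda m: 0.9,  # placeholder metric
    eval_args=None,
    calib_dataloader=calib_dataloader,
    calib_iteration=100,
)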
31 changes: 22 additions & 9 deletions neural_compressor/torch/quantization/autotune.py
@@ -13,12 +13,12 @@
# limitations under the License.

from copy import deepcopy
from typing import Dict, List, Optional, Union
from typing import Callable, List, Optional, Union

import torch

from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.common.utils import dump_elapsed_time
from neural_compressor.torch.quantization import quantize
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
@@ -46,27 +46,35 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
def autotune(
model: torch.nn.Module,
tune_config: TuningConfig,
eval_fns: Optional[Union[Dict, List[Dict]]] = None,
eval_fn: Callable,
eval_args=None,
run_fn=None,
run_args=None,
example_inputs=None,
) -> Optional[torch.nn.Module]:
"""The main entry of auto-tune."""
best_quant_model = None
evaluator.set_eval_fn_registry(eval_fns)
evaluator.self_check()
eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
baseline: float = evaluator.evaluate(model)
baseline: float = eval_func_wrapper.evaluate(model)
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
for trial_index, quant_config in enumerate(config_loader):
tuning_logger.trial_start(trial_index=trial_index)
tuning_logger.quantization_start()
logger.info(quant_config.to_dict())
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
q_model = quantize(
deepcopy(model),
quant_config=quant_config,
run_fn=run_fn,
run_args=run_args,
inplace=True,
example_inputs=example_inputs,
)
tuning_logger.quantization_end()
tuning_logger.evaluation_start()
eval_result: float = evaluator.evaluate(q_model)
eval_result: float = eval_func_wrapper.evaluate(q_model)
tuning_logger.evaluation_end()
tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
tuning_logger.trial_end(trial_index)
@@ -76,7 +84,12 @@ def autotune(
best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(
deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True
deepcopy(model),
quant_config=best_quant_config,
run_fn=run_fn,
run_args=run_args,
inplace=True,
example_inputs=example_inputs,
)
best_quant_model = q_model # quantize model inplace
break
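Finally, a hedged sketch of the PyTorch entry point, including the newly threaded-through example_inputs; the model and metric are placeholders, and the TuningConfig import follows the path given in its docstring above:

import torch

from neural_compressor.torch.quantization import TuningConfig
from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set

model = torch.nn.Linear(8, 8)          # placeholder model
example_inputs = (torch.randn(1, 8),)  # forwarded to quantize() by the updated API

def eval_acc(m) -> float:
    return 0.9  # placeholder metric

best_model = autotune(
    model=model,
    tune_config=TuningConfig(config_set=get_all_config_set(), max_trials=3, tolerable_loss=0.01),
    eval_fn=eval_acc,  # single callable replaces the former eval_fns argument
    eval_args=None,
    example_inputs=example_inputs,
)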