diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 77797f95104..2c09b50f478 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -1691,6 +1691,8 @@ class MixedPrecisionConfig(object): model_name (str, optional): The name of the model. Default value is empty. inputs (list, optional): Inputs of model, default is []. outputs (list, optional): Outputs of model, default is []. + quant_level: Support auto, 0 and 1, 0 is conservative(fallback in op type wise), + 1(fallback in op wise), auto (default) is the combination of 0 and 1. tuning_criterion (TuningCriterion object, optional): Accuracy tuning settings, it won't work if there is no accuracy tuning process. accuracy_criterion (AccuracyCriterion object, optional): Accuracy constraint settings, @@ -1739,6 +1741,7 @@ def __init__(self, model_name="", inputs=[], outputs=[], + quant_level="auto", tuning_criterion=tuning_criterion, accuracy_criterion=accuracy_criterion, excluded_precisions=[], @@ -1750,6 +1753,7 @@ def __init__(self, self.outputs = outputs self.backend = backend self.device = device + self.quant_level = quant_level self.excluded_precisions = excluded_precisions self.accuracy_criterion = accuracy_criterion self.tuning_criterion = tuning_criterion @@ -1788,6 +1792,16 @@ def model_name(self, model_name): if _check_value("model_name", model_name, str): self._model_name = model_name + @property + def quant_level(self): + """Get the quantization level.""" + return self._quant_level + + @quant_level.setter + def quant_level(self, quant_level): + """Set the quantization level.""" + self._quant_level = quant_level + @property def accuracy_criterion(self): """Get the accuracy criterion.""" diff --git a/neural_compressor/strategy/auto_mixed_precision.py b/neural_compressor/strategy/auto_mixed_precision.py index d0a1d7c1c6b..108127632de 100644 --- a/neural_compressor/strategy/auto_mixed_precision.py +++ b/neural_compressor/strategy/auto_mixed_precision.py @@ -18,11 +18,11 @@ """The auto-mixed precision strategy.""" import copy -import numpy as np -from collections import OrderedDict +from collections import OrderedDict, defaultdict +from itertools import groupby from .strategy import strategy_registry, TuneStrategy from ..utils import logger -from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler +from .utils.tuning_sampler import FallbackTuningSampler from .utils.tuning_structs import OpTuningConfig from neural_compressor.adaptor.torch_utils.mixed_precision import ipex_mixed_precision @@ -50,6 +50,7 @@ def _initialize_config(self, conf): config.domain = getattr(config, 'domain', None) config.reduce_range = getattr(config, 'reduce_range', None) config.example_inputs = getattr(config, 'example_inputs', None) + config.quant_level = getattr(config, "quant_level", "auto") return config def next_tune_cfg(self): @@ -79,54 +80,116 @@ def next_tune_cfg(self): if not target_dtypes: target_dtypes = ['bf16'] # step1. target_dtype AMAP, collect the ops that support target_dtype - bf16_items_name = [] + lower_precision_items_name = [] op_tuning_cfg = {} for idx, target_dtype in enumerate(target_dtypes): - bf16_items = tuning_space.query_items_by_quant_mode(target_dtype) - if len(bf16_items) == 0 and \ - not (idx == len(target_dtypes) - 1 and len(bf16_items_name) == 0): + lower_precision_items = tuning_space.query_items_by_quant_mode(target_dtype) + if len(lower_precision_items) == 0 and \ + not (idx == len(target_dtypes) - 1 and len(lower_precision_items_name) == 0): continue - bf16_items_name = [item.name for item in bf16_items] + lower_precision_items_name = [item.name for item in lower_precision_items] op_tuning_cfg = deepcopy(initial_op_tuning_cfg) - for op_name_type in bf16_items_name: + for op_name_type in lower_precision_items_name: op_tuning_cfg[op_name_type] = \ OpTuningConfig(op_name_type[0], op_name_type[1], target_dtype, tuning_space) calib_sampling_size = 1 op_tuning_cfg['calib_sampling_size'] = calib_sampling_size yield op_tuning_cfg - # step2. fallback - target_dtype = 'fp32' - fallback_items_name_lst = bf16_items_name[::-1] + # step 2, fallback op into fp32 + # quant_level: + # auto: op-type-wise -> op-wise + # 0: op-type wise + # 1: op-wise + + # if quant level is auto or 0, do op type wise fallback + target_dtype = "fp32" + fallback_items_name_lst = lower_precision_items_name[::-1] if fallback_items_name_lst: - logger.info(f"Start to fallback op to {target_dtype} one by one.") - self._fallback_started() - op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst))) + logger.info("[Strategy] start fallback op into fp32.") initial_op_tuning_cfg = deepcopy(op_tuning_cfg) + if self.config.quant_level in ["auto", 0]: + logger.info(f"[Strategy] fallback op into fp32 in op type wise, \ + as quant level is {self.config.quant_level}") + for op_tuning_cfg in self.fallback_in_op_type_wise(tuning_space, fallback_items_name_lst,\ + deepcopy(initial_op_tuning_cfg), target_dtype): + yield op_tuning_cfg + + # if quant level is auto or 1, do op instance fallback + if self.config.quant_level in ["auto", 1]: + logger.info(f"[Strategy] fallback op into fp32 in op wise, \ + as quant level is {self.config.quant_level}") + for op_tuning_cfg in self.fallback_in_op_wise(tuning_space, fallback_items_name_lst,\ + deepcopy(initial_op_tuning_cfg), target_dtype): + yield op_tuning_cfg + + def fallback_in_op_type_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype): + """Fallback op in op type wise. + + Args: + tuning_space: tuning space + fallback_items_name_lst: the list of items to be fallback + initial_op_tuning_cfg: initial tuning config + target_dtype: target data type, such as fp32 + + Yields: + tuning config + """ + fallback_items_name_lst.sort(key=lambda x: x[1]) + op_type_groups = groupby(fallback_items_name_lst, key=lambda x: x[1]) + # key: ((op1_name, op_type1),(op2_name, op_type1), (op3_name, op_type1), ...) + # value: target dtype + ops_dtypes = OrderedDict() + for op_type, op_lst in op_type_groups: + ops_dtypes[tuple(op_lst)] = target_dtype fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], - initial_op_tuning_cfg=initial_op_tuning_cfg, - op_dtypes=op_dtypes, accumulate=False) + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=ops_dtypes, accumulate=False) op_fallback_acc_impact = OrderedDict() for op_index, op_tuning_cfg in enumerate(fallback_sampler): - op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + op_tuning_cfg['calib_sampling_size'] = -1 + yield op_tuning_cfg + acc, _ = self.last_tune_result + op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc + + def fallback_in_op_wise(self, tuning_space, fallback_items_name_lst, initial_op_tuning_cfg, target_dtype): + """Fallback op in op wise. + + Args: + tuning_space: tuning space + fallback_items_name_lst: the list of items to be fallback + initial_op_tuning_cfg: initial tuning config + target_dtype: target data type, such as fp32 + + Yields: + tuning config + """ + op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst))) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=False) + op_fallback_acc_impact = OrderedDict() + for op_index, op_tuning_cfg in enumerate(fallback_sampler): + op_tuning_cfg['calib_sampling_size'] = -1 yield op_tuning_cfg acc, _ = self.last_tune_result op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc # do accumulated fallback according to the order in the previous stage if len(op_fallback_acc_impact) > 0: - ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key], - reverse=self.higher_is_better) + ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key], \ + reverse=self.higher_is_better) op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst))) logger.info("Start to accumulate fallback to {target_dtype}.") - initial_op_tuning_cfg = deepcopy(op_tuning_cfg) + initial_op_tuning_cfg = copy.deepcopy(op_tuning_cfg) fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], - initial_op_tuning_cfg=initial_op_tuning_cfg, - op_dtypes=op_dtypes, accumulate=True) + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=True) for op_tuning_cfg in fallback_sampler: - op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + op_tuning_cfg['calib_sampling_size'] = -1 yield op_tuning_cfg + def traverse(self): """Traverse the tuning space according to auto-mixed precision strategy.""" if self.config.backend == "ipex": diff --git a/neural_compressor/strategy/utils/tuning_sampler.py b/neural_compressor/strategy/utils/tuning_sampler.py index 7bf241b5343..01cac55df3e 100644 --- a/neural_compressor/strategy/utils/tuning_sampler.py +++ b/neural_compressor/strategy/utils/tuning_sampler.py @@ -20,7 +20,7 @@ from itertools import product import copy from collections import deque, OrderedDict, defaultdict -from typing import List, Dict, Any +from typing import List, Dict, Any, Union, Tuple from .tuning_space import TuningSpace, pattern_to_internal, pattern_to_path, quant_mode_from_pattern from .tuning_structs import OpTuningConfig from ...utils import logger @@ -382,8 +382,8 @@ class FallbackTuningSampler(TuningSampler): def __init__(self, tuning_space: TuningSpace, tuning_order_lst: List[TuningOrder], - initial_op_tuning_cfg: Dict[tuple, Any], - op_dtypes: Dict[str, str], + initial_op_tuning_cfg: Dict[Tuple, Any], + op_dtypes: Dict[Union[Tuple, Tuple[Tuple]], str], accumulate: bool, skip_first: bool = True ): @@ -414,21 +414,23 @@ def __iter__(self): # Only support fallback to lower precision. if not self.accumulate: new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg) - full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype) - self.op_complete_path[op_name_type] = copy.deepcopy(full_path) - config_args = {} - self._set_dtype(op_name_type, config_args) - internal_pattern = pattern_to_internal(target_dtype) - quant_mode = quant_mode_from_pattern(internal_pattern) - new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1], - quant_mode, self.tuning_space, - kwargs=config_args) + op_name_type_lst = [op_name_type] if len(op_name_type) != 1 and \ + isinstance(op_name_type[1], str) else op_name_type + for op_name_type in op_name_type_lst: + full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype) + self.op_complete_path[op_name_type] = copy.deepcopy(full_path) + config_args = {} + self._set_dtype(op_name_type, config_args) + internal_pattern = pattern_to_internal(target_dtype) + quant_mode = quant_mode_from_pattern(internal_pattern) + new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1], quant_mode, \ + self.tuning_space, kwargs=config_args) - new_tune_cfg.update({op_name_type: new_op_config}) + new_tune_cfg.update({op_name_type: new_op_config}) if self.accumulate and skip_first: # skip the first one skip_first = False continue - logger.info(f"fallback {op_name_type} to {target_dtype}") + logger.info(f"fallback {op_name_type_lst} to {target_dtype}") yield new_tune_cfg # need to skip the first one class LowerBitsSampler(TuningSampler): diff --git a/test/mixed_precision/test_mixed_precision.py b/test/mixed_precision/test_mixed_precision.py index 28509ff4bd4..e587df3e5fc 100644 --- a/test/mixed_precision/test_mixed_precision.py +++ b/test/mixed_precision/test_mixed_precision.py @@ -328,7 +328,7 @@ def test_mixed_precision_with_eval_func(self): def eval(model): return 0.5 - result = [0., 0.1, 0.102, 0.1006, 0.1005, 0.1004, 0.1002] + result = [0., 0.1, 0.102, 0.1003, 0.1005, 0.1004, 0.1002] def eval2(model): del result[0] @@ -371,6 +371,83 @@ def eval2(model): output_model = fit(self.tf_model, conf, eval) self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node])) + + def test_mixed_precision_with_quant_level_1(self): + + result = [0., 0.1, 0.102] + def eval_func(model): + del result[0] + return result[0] + + conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto") + + output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func) + self.assertTrue(any([i.op == 'Cast' for i in output_model.graph_def.node])) + self.assertEqual(conf.inputs, 'input') + self.assertEqual(conf.outputs, 'final') + + def test_mixed_precision_with_quant_level_2(self): + + result = [0., 1, 0.9, 1.1] + # meet acc if fallback all conv + def eval_func(model): + del result[0] + return result[0] + + conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto") + + output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func) + # no cast in output model + self.assertFalse(any([i.op == 'Cast' for i in output_model.graph_def.node])) + + def test_mixed_precision_with_quant_level_3(self): + + result = [0., 1, 0.9, 0.9, 1.1] + # meet acc if fallback 1 conv + def eval_func(model): + del result[0] + return result[0] + + conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level="auto") + + output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func) + # no cast in output model + count_cast = 0 + for node in output_model.graph_def.node: + if node.op == "Cast": + count_cast += 1 + self.assertEqual(count_cast, 4) + + def test_mixed_precision_with_quant_level_4(self): + + result = [0., 1, 0.9, 0.9, 1.1] + # meet acc if fallback the second conv + def eval_func(model): + del result[0] + return result[0] + + conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=1) + + output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func) + # no cast in output model + count_cast = 0 + for node in output_model.graph_def.node: + if node.op == "Cast": + count_cast += 1 + self.assertEqual(count_cast, 4) + + def test_mixed_precision_with_quant_level_5(self): + result = [0., 1, 0.9, 0.9, 0.9] + # meet not meet + def eval_func(model): + del result[0] + return result[0] + + conf = MixedPrecisionConfig(inputs="input", outputs="final", quant_level=0) + + output_model = mix_precision.fit(self.tf_model, conf, eval_func=eval_func) + self.assertIsNone(output_model) + @unittest.skipIf(PT_VERSION.release < Version("1.11.0").release, "Please use PyTroch 1.11 or higher version for mixed precision.") def test_mixed_precision_with_eval_func_pt(self):