diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index c23d702aa60..26a1ff8baaa 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -160,7 +160,7 @@ def __init__(self, framework_specific_info): # sq algo and args self.sq = None - self.cur_sq_args = None + self.cur_sq_args = {} def smooth_quant(self, model, dataloader, iterations, alpha=0.5, folding=True, percentile=99.999, op_types=['MatMul', 'Gemm', 'Conv', 'FusedConv'], @@ -186,29 +186,38 @@ def smooth_quant(self, model, dataloader, iterations, alpha=0.5, folding=True, return self.smooth_quant_model from .ox_utils.smooth_quant import ORTSmoothQuant - # TODO remove quantize_config as it no consumer - quantize_config = None - # pre-optimization -> sq + + # set params to cur_sq_args + self.cur_sq_args['alpha'] = alpha + self.cur_sq_args['folding'] = folding + self.cur_sq_args['percentile'] = percentile + self.cur_sq_args['op_types'] = op_types + self.cur_sq_args['scales_per_op'] = scales_per_op + self.cur_sq_args['calib_iter'] = iterations + + # pre-optimization self._pre_optimize(model) + # assign the algo to the adaptor, so adaptor can call it later when needed self.sq = ORTSmoothQuant(self.pre_optimized_model, dataloader, self.reduce_range, self.backend) - self.smooth_quant_model = self.sq.transform( - alpha, folding, percentile, op_types, scales_per_op, iterations, quantize_config) + self.sq.record_max_info = record_max_info + self.smooth_quant_model = self.sq.transform(**self.cur_sq_args) logger.info("Updated the pre-optimized model with smooth quant model.") # TODO double-check the smooth_quant_model and pre_optimized_model to make sure there no two fp32 model replicas self.pre_optimized_model = self.smooth_quant_model return self.smooth_quant_model def _need_smooth_quant(self, tune_cfg) -> bool: - # compare the alpha from tune_cfg and current alpha to decide whether re-smooth model or not - # TODO - return False - - def 
_parse_sq_args(self, tune_cfg, cur_sq_args) -> Dict: - # parse the sq args according to the tune cfg and current sq args - # TODO - return {} - + """Check the model needs smooth quant or not.""" + recipe_cfgs = tune_cfg.get('recipe_cfgs', None) + if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \ + and recipe_cfgs['smooth_quant_args'].get('alpha', None): + # update alpha according to tune_cfg + self.cur_sq_args['alpha'] = \ + tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha'] + return True + else: + return False @dump_elapsed_time("Pass quantize model") def quantize(self, tune_cfg, model, data_loader, q_func=None): @@ -224,17 +233,9 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): Returns: (dict): quantized model - """ - # two steps to re-smooth the model if needed - if self._need_smooth_quant(tune_cfg): - # step1. recover the sq to original fp32 model - self.sq.recover() - new_sq_args = self._parse_sq_args(tune_cfg, self.cur_sq_args) - # step2. re-smooth the model with new alpha - model = self.smooth_quant(model=model, dataloader=data_loader, iterations=new_sq_args['iterations'],\ - alpha=new_sq_args['alpha'], folding=new_sq_args['folding'], scales_per_op=new_sq_args['scales_per_op']) + """ assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME" - if self.smooth_quant_model is not None: + if self.smooth_quant_model is not None and model.is_smoothquant_model(): model = self.smooth_quant_model elif self.pre_optimized_model is not None: model = self.pre_optimized_model @@ -277,47 +278,19 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format( repr(e))) tmp_model = model + + # smooth quant the model if needed + if self._need_smooth_quant(tune_cfg) and not tmp_model.is_smoothquant_model(): + self.sq.model = tmp_model + self.sq.record_max_info = False + tmp_model = self.sq.transform(**self.cur_sq_args) + 
iterations = tune_cfg.get('calib_iteration', 1) calib_sampling_size = tune_cfg.get('calib_sampling_size', 1) if not self.dynamic: - if isinstance(data_loader, BaseDataLoader): - batch_size = data_loader.batch_size - try: - for i in range(batch_size): - if calib_sampling_size % (batch_size - i) == 0: - calib_batch_size = batch_size - i - if i != 0: # pragma: no cover - logger.warning("Reset `calibration.dataloader.batch_size` field " - "to {}".format(calib_batch_size) + - " to make sure the sampling_size is " - "divisible exactly by batch size") - break - tmp_iterations = int(math.ceil(calib_sampling_size / calib_batch_size)) - data_loader.batch(calib_batch_size) - quantize_params = self._get_quantize_params(tmp_model, data_loader, \ - quantize_config, tmp_iterations) - except Exception as e: # pragma: no cover - if 'Got invalid dimensions for input' in str(e): - logger.warning("Please set sampling_size to a multiple of {}".format( - str(e).partition('Expected: ')[2].partition('\n')[0])) - exit(0) - logger.warning( - "Fail to forward with batch size={}, set to {} now.". - format(batch_size, 1)) - data_loader.batch(1) - quantize_params = self._get_quantize_params(tmp_model, data_loader, \ - quantize_config, calib_sampling_size) - else: # pragma: no cover - if hasattr(data_loader, 'batch_size') and \ - calib_sampling_size % data_loader.batch_size != 0: - logger.warning( - "Please note that calibration sampling size {} " \ - "isn't divisible exactly by batch size {}. " \ - "So the real sampling size is {}.". 
- format(calib_sampling_size, data_loader.batch_size, - data_loader.batch_size * iterations)) - quantize_params = self._get_quantize_params(tmp_model, data_loader, \ - quantize_config, iterations) + calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations) + quantize_params = self._get_quantize_params(tmp_model, data_loader, \ + quantize_config, calib_iterations) else: quantize_params = None self.quantize_params = quantize_params @@ -350,6 +323,45 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): self._dump_model_op_stats(tmp_model) tmp_model.topological_sort() return tmp_model + + def _reset_calib_iter(self, data_loader, cfg_calib_sampling_size, cfg_calib_iter): + """Check and reset calibration iterations according to calib_sampling_size and dataloader batch_size.""" + if isinstance(data_loader, BaseDataLoader): + batch_size = data_loader.batch_size + try: + for i in range(batch_size): + if cfg_calib_sampling_size % (batch_size - i) == 0: + calib_batch_size = batch_size - i + if i != 0: # pragma: no cover + logger.warning("Reset `calibration.dataloader.batch_size` field " + "to {}".format(calib_batch_size) + + " to make sure the sampling_size is " + "divisible exactly by batch size") + break + tmp_iterations = int(math.ceil(cfg_calib_sampling_size / calib_batch_size)) + data_loader.batch(calib_batch_size) + calib_iterations = tmp_iterations + except Exception as e: # pragma: no cover + if 'Got invalid dimensions for input' in str(e): + logger.warning("Please set sampling_size to a multiple of {}".format( + str(e).partition('Expected: ')[2].partition('\n')[0])) + exit(0) + logger.warning( + "Fail to forward with batch size={}, set to {} now.". 
+ format(batch_size, 1)) + data_loader.batch(1) + calib_iterations = cfg_calib_sampling_size + else: # pragma: no cover + if hasattr(data_loader, 'batch_size') and \ + cfg_calib_sampling_size % data_loader.batch_size != 0: + logger.warning( + "Please note that calibration sampling size {} " \ + "isn't divisible exactly by batch size {}. " \ + "So the real sampling size is {}.". + format(cfg_calib_sampling_size, data_loader.batch_size, + data_loader.batch_size * cfg_calib_iter)) + calib_iterations = cfg_calib_iter + return calib_iterations def _generate_qconfig(self, model, tune_cfg, quantize_params): tune_cfg = copy.deepcopy(tune_cfg) @@ -512,10 +524,11 @@ def _get_quantize_params(self, model, data_loader, quantize_config, iterations): model = ONNXModel(model) black_nodes = [node for node in quantize_config if quantize_config[node]=='fp32'] white_nodes = [node for node in quantize_config if quantize_config[node]!='fp32'] + augment = ONNXRTAugment(model, \ data_loader, self.quantizable_op_types, \ black_nodes=black_nodes, white_nodes=white_nodes, \ - iterations=list(range(0, quantize_config['calib_iteration'])), + iterations=list(range(0, iterations)), \ backend=self.backend, reduce_range=self.reduce_range) self.min_max = augment.dump_minmax(quantize_config) quantize_params = augment.dump_calibration(quantize_config, min_max=self.min_max) diff --git a/neural_compressor/adaptor/ox_utils/smooth_quant.py b/neural_compressor/adaptor/ox_utils/smooth_quant.py index 7bca8cf00c9..87d9fe8f9f6 100644 --- a/neural_compressor/adaptor/ox_utils/smooth_quant.py +++ b/neural_compressor/adaptor/ox_utils/smooth_quant.py @@ -121,6 +121,7 @@ def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionP self.tensors_to_node = {} self.replace_input = [] self.ops_to_absorb = [] + self.record_max_info = False self._build_absorb_function() def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm', 'Conv', 'MatMul', 'FusedConv'], @@ -142,6 +143,7 @@ 
def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm' A FP32 model with the same architecture as the orig model but with different weight which will be benefit to quantization """ + self.clean() if isinstance(alpha, float) and (alpha < 0 or alpha > 1): logger.warning("alpha should be a float value in [0, 1] or 'auto' ") if alpha < 0: @@ -155,12 +157,16 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm' if need_calibration: self._dump_op_info(percentile, op_types, calib_iter, quantize_config) - if alpha == 'auto': - alpha = self._auto_tune_alpha(calib_iter, **auto_alpha_args) + if self.record_max_info: + return self.model + + if alpha == 'auto': + alpha = self._auto_tune_alpha(calib_iter, **auto_alpha_args) + + scales = self._get_smooth_scales(alpha) + self._insert_smooth_mul_op(scales) + self._adjust_weights(scales) - scales = self._get_smooth_scales(alpha) - self._insert_smooth_mul_op(scales) - self._adjust_weights(scales) self.model.add_nodes(self.new_added_mul_nodes) self.model.model.graph.value_info.extend(self.new_added_value_info) self.model.add_initializers(self.new_init_tensors) @@ -202,6 +208,15 @@ def recover(self): self.new_added_mul_nodes = [] self.new_init_tensors = [] self.new_added_value_info = [] + self.replace_input = [] + + def clean(self): + """Clean data collected from calibration.""" + self.tensor_scales_info = {} + self.new_added_mul_nodes = [] + self.new_init_tensors = [] + self.new_added_value_info = [] + self.replace_input = [] def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): """Check need calibration or not. 
diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 9eed30e3636..9205d507f1e 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -846,8 +846,8 @@ def smooth_quant_args(val=None): smooth_quant_args = {"alpha": numpy.arange(0.1, 0.5, 0.05).tolist()} """ if isinstance(v, str): - assert v == "auto", "the alpha of sq only supports float and 'auto'" - elif isinstance(v, float) or isinstance(v, int): + assert v == "auto", "the alpha of sq only supports float, list and 'auto'" + elif isinstance(v, float) or isinstance(v, int) or isinstance(v, list): continue else: logger.warning("Ignore the alpha as it's not a list, int or float.") diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 29b1569ec9f..575c24631bf 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -792,3 +792,14 @@ def match_parent_path( current_node = matched_parent return matched_parents + + def is_smoothquant_model(self): + """Check the model is smooth quantized or not. + + Returns: + bool: the model is smooth quantized or not. 
+ """ + for init in self.model.graph.initializer: + if "_smooth_scale" in init.name: + return True + return False \ No newline at end of file diff --git a/test/algorithm/test_smooth_quant_onnx.py b/test/algorithm/test_smooth_quant_onnx.py index 1f79b12fce5..13d1f4f5bdd 100644 --- a/test/algorithm/test_smooth_quant_onnx.py +++ b/test/algorithm/test_smooth_quant_onnx.py @@ -6,6 +6,52 @@ import shutil from neural_compressor.data import Datasets, DATALOADERS from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant +import logging +logger = logging.getLogger("neural_compressor") + +def check_model_is_same(model_proto1, model_proto2): + # Compare if both models have the same number of nodes + if len(model_proto1.graph.node) != len(model_proto2.graph.node): + return False + + # Compare individual nodes in both models + for node1, node2 in zip(model_proto1.graph.node, model_proto2.graph.node): + print(node1.name, node2.name) + # Check node name, input, output, and op_type + if node1.name != node2.name or \ + node1.op_type != node2.op_type or \ + node1.input != node2.input or \ + node1.output != node2.output: + return False + + # Check node attributes + if len(node1.attribute) != len(node2.attribute): + return False + + for attr1, attr2 in zip(node1.attribute, node2.attribute): + if attr1.name == attr2.name: + if attr1.type == onnx.AttributeProto.FLOATS: + # Compare float attributes for exact equality + if not attr1.floats == attr2.floats: + return False + elif attr1.type == onnx.AttributeProto.INTS: + # Compare int attributes + if attr1.ints != attr2.ints: + return False + # Compare initializer + init1 = {init.name: init for init in model_proto1.graph.initializer} + init2 = {init.name: init for init in model_proto2.graph.initializer} + for name in init1.keys(): + if name not in init2 or \ + not (numpy_helper.to_array(init1[name]) == numpy_helper.to_array(init2[name])).all(): + return False + + # Compare model inputs and outputs + if 
model_proto1.graph.input != model_proto2.graph.input or \ + model_proto1.graph.output != model_proto2.graph.output: + return False + + return True def build_onnx_model(): @@ -43,6 +89,8 @@ def setUpClass(self): self.model = build_onnx_model() dataset = Datasets("onnxrt_qdq")["dummy_v2"]((5,5), (5,1)) self.dataloader = DATALOADERS['onnxrt_qlinearops'](dataset) + fixed_dataset = Datasets("onnxrt_qdq")['dummy'](shape=(5,5,5), label=True) + self.fixed_dataloader = DATALOADERS['onnxrt_qlinearops'](fixed_dataset) @classmethod def tearDownClass(self): @@ -110,6 +158,89 @@ def test_sq(self): sq_tensor = numpy_helper.to_array(sq.model.get_initializer(init.name)) self.assertAlmostEqual(tensor[0][0], sq_tensor[0][0], 4) - + def _test_sq_tune_alpha_common(self, eval_func, alpha=np.arange(0.1, 0.2, 0.05).tolist(), quant_level=1): + from neural_compressor import quantization + from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion + tuning_criterion = TuningCriterion(max_trials=8) + + fp32_model = copy.deepcopy(self.model) + conf = PostTrainingQuantConfig( + quant_level=quant_level, + tuning_criterion=tuning_criterion, + calibration_sampling_size=4, + recipes={"smooth_quant": True, + "smooth_quant_args": {"alpha": alpha} + } + ) + q_model = quantization.fit( + fp32_model, + conf, + calib_dataloader=self.fixed_dataloader, + eval_func=eval_func, + ) + self.assertIsNotNone(q_model) + return q_model + + def test_tune_sq_alpha(self): + from functools import partial + def fake_eval(model, eval_result_lst): + acc = eval_result_lst.pop(0) + return acc + + # test for quantized models generated by int alpha and list alpha whether they are the same + partial_fake_eval = partial(fake_eval, eval_result_lst = [1, 1.1] ) + q_model_without_tune = self._test_sq_tune_alpha_common(partial_fake_eval, alpha=0.5) + partial_fake_eval = partial(fake_eval, eval_result_lst = [1, 0.8, 1.1] ) + q_model_with_tune = self._test_sq_tune_alpha_common(partial_fake_eval, alpha=[0.4, 0.5]) + 
self.assertTrue(check_model_is_same(q_model_without_tune.model, q_model_with_tune.model)) + + # test for alpha is a list + for eval_result_lst, note in [ + ([1, 0.8, 1.1, 0.7, 1.1], "Expect tuning ends at 2nd trial with alpha is 0.15"), + ([1, 0.8, 0.9, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.15"), + ([1, 0.9, 0.8, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.10") + ]: + logger.info(f"test_sq_tune_alpha_common with eval_result_lst: {eval_result_lst}") + logger.info(note) + partial_fake_eval = partial(fake_eval, eval_result_lst = eval_result_lst ) + self._test_sq_tune_alpha_common(partial_fake_eval) + + # test for various alphas + for eval_result_lst, alpha, note in [ + ([1, 0.8, 1.1, 0.7, 1.1], 0.5 ,"Expect tuning ends at 2nd trial with alpha is 0.5 and not tune sq's alpha."), + ([1, 0.8, 0.9, 0.7, 1.1], [0.5], "Expect tuning ends at 4th trial with alpha is 0.5 and not tune sq's alpha."), + ([1, 0.9, 0.8, 0.7, 1.1], [0.5, 0.7, 0.9] ,"Expect tuning ends at 4th trial with alpha is 0.5") + ]: + logger.info(f"test_sq_tune_alpha_common with eval_result_lst: {eval_result_lst}, alpha: {alpha}") + logger.info(note) + partial_fake_eval = partial(fake_eval, eval_result_lst=eval_result_lst) + self._test_sq_tune_alpha_common(partial_fake_eval, alpha=alpha) + + # test for quant_level is auto or 0 + for eval_result_lst, alpha, quant_level, note in [ + ( + [1, 0.8, 1.1, 0.7, 1.1], + np.arange(0.1, 0.2, 0.05).tolist(), + "auto", + "Expect tuning ends at 2nd trial with alpha is 0.15." + ), + ( + [1, 0.8, 0.9, 0.7, 1.1], + np.arange(0.1, 0.2, 0.05).tolist(), + "auto", + "Expect tuning ends at 4th trial with alpha is 0.15 at basic strategy." 
+ ), + ( + [1, 1.1, 0.8, 0.7, 1.1], + np.arange(0.1, 0.2, 0.05).tolist(), + 0, + "Expect tuning ends at 1th trial with alpha is 0.1") + ]: + logger.info(f"test_sq_tune_alpha_common with ") + logger.info(f"eval_result_lst: {eval_result_lst}, alpha: {alpha}, quant_level: {quant_level}") + logger.info(note) + partial_fake_eval = partial(fake_eval, eval_result_lst=eval_result_lst) + self._test_sq_tune_alpha_common(partial_fake_eval, alpha=alpha, quant_level=quant_level) + if __name__ == '__main__': unittest.main()