diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index bd263f2b173..151b131b204 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): if len(output.shape) <= 2: max_value = torch.max(torch.abs(output)) else: - max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values + output = output.reshape(output.shape[0], -1) + output_q = output_q.reshape(output_q.shape[0], -1) + max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) max_value = torch.clip(max_value, 1e-5) output = output / max_value ##FIXME need copy not replace output_q = output_q / max_value @@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales): weight_scale = self._reshape_scale_for_weight(layer, weight_scale) layer.update_scale(input_scale, weight_scale) ##FIXME - def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): self._change_qdq_for_auto(enable=False) forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output @@ -793,7 +795,7 @@ def dict_to_list(dic): return best_alpha def _auto_tune_alpha_new( - self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min" + self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min" ): """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly. @@ -801,7 +803,7 @@ def _auto_tune_alpha_new( Also, it reduces the memory usage at the cost of increasingtuning time TODO may have compatibility issue when setting folding=True :param input_maxes: - :param auto_calib_iter: + :param calib_sample_num: :param alpha_min: :param alpha_max: :param alpha_step: @@ -828,9 +830,10 @@ def _auto_tune_alpha_new( self.absorb_to_layer, input_maxes, default_alpha, tuning=True ) self._update_scales_for_auto(absorb_input_scales, weight_scales) - loss_alphas = {} cnt = 0 - multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter + alpha_update_iter = 0 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num best_alphas = default_alpha if not self.dataloader: @@ -838,6 +841,7 @@ def _auto_tune_alpha_new( return best_alphas try: for input, label in self.dataloader: + loss_alphas = {} best_alphas_per_module = best_alphas if isinstance(best_alphas, dict): for key in self.absorb_to_layer.keys(): @@ -845,7 +849,7 @@ def _auto_tune_alpha_new( for layer_name in layer_names: best_alphas_per_module[layer_name] = best_alphas_per_module[key] - loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) + loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) if loss_alphas == {}: loss_alphas = loss_tmp else: @@ -853,26 +857,22 @@ def _auto_tune_alpha_new( cur_loss = loss_alphas[key] for alpha_key in cur_loss.keys(): cur_loss[alpha_key] += loss_tmp[key][alpha_key] - if isinstance(input, list): - input = move_input_to_device(input, self.device) - for inp in input: - cnt += inp.shape[0] - else: - cnt += input.shape[0] - - if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor: + cnt += self.dataloader.batch_size + if cnt // multiply_factor >= 1: + alpha_update_iter += 1 + cnt = 0 best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) for key in best_alphas.keys(): - logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}") + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") absorb_input_scales, weight_scales = self._cal_scales( self.absorb_to_layer, input_maxes, best_alphas, tuning=True ) self._update_scales_for_auto(absorb_input_scales, weight_scales) - loss_alphas = {} ##TODO check need to remove this one - if cnt >= auto_calib_iter: + if cnt >= calib_sample_num: break except: for input in self.dataloader: + loss_alphas = {} best_alphas_per_module = best_alphas if isinstance(best_alphas, dict): for key in self.absorb_to_layer.keys(): @@ -880,7 +880,7 @@ def _auto_tune_alpha_new( for layer_name in layer_names: best_alphas_per_module[layer_name] = best_alphas_per_module[key] - loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) + loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) if loss_alphas == {}: loss_alphas = loss_tmp else: @@ -888,28 +888,24 @@ def _auto_tune_alpha_new( cur_loss = loss_alphas[key] for alpha_key in cur_loss.keys(): cur_loss[alpha_key] += loss_tmp[key][alpha_key] - if isinstance(input, list): - input = move_input_to_device(input, self.device) - for inp in input: - cnt += inp.shape[0] - else: - cnt += input.shape[0] + cnt += self.dataloader.batch_size + if cnt // multiply_factor >= 1: + alpha_update_iter += 1 + cnt = 0 - if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor: best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) for key in best_alphas.keys(): - logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}") + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") absorb_input_scales, weight_scales = self._cal_scales( self.absorb_to_layer, input_maxes, best_alphas, tuning=True ) self._update_scales_for_auto(absorb_input_scales, weight_scales) - loss_alphas = {} ##TODO check need to remove this one - if cnt >= auto_calib_iter: + if cnt >= calib_sample_num: break best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) for key in best_alphas.keys(): - logger.info(f"final {key}:{best_alphas[key]}") + logger.info(f"Final alpha {key}:{best_alphas[key]}") self._qdq_model_unwrapper_for_auto() logger.info("auto tuning done") return best_alphas @@ -999,7 +995,7 @@ def transform( if alpha == "auto": self.alpha_per_layer = self._auto_tune_alpha_new( - input_maxes_abs, auto_calib_iter=32, **auto_alpha_args + input_maxes_abs, calib_sample_num=32, **auto_alpha_args ) ##save the alpha if alpha == "auto": diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py index 4664b976678..4b4201edcc3 100644 --- a/test/algorithm/test_smooth_quant.py +++ b/test/algorithm/test_smooth_quant.py @@ -53,10 +53,11 @@ def __iter__(self): class LLMCalibDataloader: def __init__(self): - self.batch_size = 1 + self.batch_size = 3 def __iter__(self): - yield torch.ones([1, 3], dtype=torch.long) + for i in range(4): + yield torch.ones([3, 3], dtype=torch.long) class TestSqDepthwiseConv(unittest.TestCase): @@ -64,7 +65,7 @@ class TestSqDepthwiseConv(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -141,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -181,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3, 1, 1)) @@ -386,21 +387,21 @@ class TestSqListInput(unittest.TestCase): def setUpClass(self): class ListDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield [torch.rand((1, 3))] class TupleDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield (torch.rand((1, 3))) class ListTupleDataLoader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): input1 = torch.rand((1, 3)) @@ -499,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3)) @@ -535,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 3)) @@ -736,6 +737,8 @@ def test_sq_qkv(self): sq.transform(alpha=0.5, calib_iter=-1, folding=False) assert isinstance(sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper) + +class TestExample(unittest.TestCase): def test_sq_quant(self): from neural_compressor import PostTrainingQuantConfig, quantization @@ -763,10 +766,11 @@ def forward(self, x): class CalibDataloader: def __init__(self): - self.batch_size = 1 + self.batch_size = 3 def __iter__(self): - yield input_ids + for i in range(4): + yield input_ids def calib_func(model): for i in range(10): @@ -935,7 +939,7 @@ class TestSqSkipOp(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 4)) @@ -992,7 +996,7 @@ class TestSqSkipOp_attn(unittest.TestCase): def setUpClass(self): class RandDataloader: def __init__(self): - pass + self.batch_size = 1 def __iter__(self): yield torch.rand((1, 4))