Commit

fix bug in smoothquant for auto alpha
Signed-off-by: y <[email protected]>
xin3he committed Sep 27, 2023
1 parent 0426f82 commit 3ae01fc
Showing 2 changed files with 34 additions and 35 deletions.
58 changes: 27 additions & 31 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0):
        if len(output.shape) <= 2:
            max_value = torch.max(torch.abs(output))
        else:
-            max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values
+            output = output.reshape(output.shape[0], -1)
+            output_q = output_q.reshape(output_q.shape[0], -1)
+            max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
        max_value = torch.clip(max_value, 1e-5)
        output = output / max_value ##FIXME need copy not replace
        output_q = output_q / max_value
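Why this hunk matters: the removed line took a per-row max over a flattened view but never reshaped output/output_q and left max_value as a 1-D vector, so the division a few lines below could not broadcast per sample. A minimal, self-contained sketch of the per-sample normalization the new lines perform; the tensor shapes and the final loss line are illustrative, not taken from the repository:

# Sketch only: per-sample normalization before comparing fp32 and quantized outputs.
import torch

output = torch.randn(4, 8, 16)                          # fp32 activations (illustrative shape)
output_q = output + 0.01 * torch.randn_like(output)     # stand-in for the quantized outputs

if len(output.shape) > 2:
    output = output.reshape(output.shape[0], -1)
    output_q = output_q.reshape(output_q.shape[0], -1)
    # Per-sample max; unsqueeze(-1) keeps a [batch, 1] column so the divisions
    # below broadcast row-wise instead of failing on a 1-D vector.
    max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
    max_value = torch.clip(max_value, 1e-5)
    output = output / max_value
    output_q = output_q / max_value

loss = torch.sum(torch.abs(output - output_q) ** 1.0)   # "abs" loss with loss_alpha=1.0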
@@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales):
                weight_scale = self._reshape_scale_for_weight(layer, weight_scale)
                layer.update_scale(input_scale, weight_scale) ##FIXME

-    def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
+    def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
        self._change_qdq_for_auto(enable=False)

        forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output
@@ -793,15 +795,15 @@ def dict_to_list(dic):
        return best_alpha

    def _auto_tune_alpha_new(
-        self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
+        self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
    ):
        """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
        This function takes the quantization of earlier layers into account when running qdq on a layer.
        It also reduces memory usage at the cost of increased tuning time.
        TODO: may have a compatibility issue when setting folding=True
        :param input_maxes:
-        :param auto_calib_iter:
+        :param calib_sample_num:
        :param alpha_min:
        :param alpha_max:
        :param alpha_step:
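For orientation: with these defaults the tuner sweeps a small grid of alpha values and periodically refreshes its per-layer choice during calibration. A hedged sketch of how that grid and the update interval could be derived from the parameters above (one plausible construction, not the repository's exact code):

# Sketch only: candidate alpha grid and update interval implied by the defaults.
alpha_min, alpha_max, alpha_step = 0.3, 0.7, 0.05
n_steps = int(round((alpha_max - alpha_min) / alpha_step)) + 1
alpha_space = [round(alpha_min + i * alpha_step, 2) for i in range(n_steps)]
# alpha_space -> [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7]

calib_sample_num = 32
multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num
# multiply_factor -> 8: the running best alpha is re-estimated roughly every 8 calibration samples.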
@@ -828,88 +830,82 @@ def _auto_tune_alpha_new(
            self.absorb_to_layer, input_maxes, default_alpha, tuning=True
        )
        self._update_scales_for_auto(absorb_input_scales, weight_scales)
-        loss_alphas = {}
-        cnt = 0
-        multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter
+        alpha_update_iter = 0
+        # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
+        multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num

        best_alphas = default_alpha
        if not self.dataloader:
            self._qdq_model_unwrapper_for_auto()
            return best_alphas
        try:
            for input, label in self.dataloader:
+                loss_alphas = {}
                best_alphas_per_module = best_alphas
                if isinstance(best_alphas, dict):
                    for key in self.absorb_to_layer.keys():
                        layer_names = self.absorb_to_layer[key]
                        for layer_name in layer_names:
                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]

-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                if loss_alphas == {}:
                    loss_alphas = loss_tmp
                else:
                    for key in loss_alphas.keys():
                        cur_loss = loss_alphas[key]
                        for alpha_key in cur_loss.keys():
                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]

-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0
                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                    for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                    absorb_input_scales, weight_scales = self._cal_scales(
                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                    )
                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {} ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                    break
        except:
            for input in self.dataloader:
+                loss_alphas = {}
                best_alphas_per_module = best_alphas
                if isinstance(best_alphas, dict):
                    for key in self.absorb_to_layer.keys():
                        layer_names = self.absorb_to_layer[key]
                        for layer_name in layer_names:
                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]

-                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
                if loss_alphas == {}:
                    loss_alphas = loss_tmp
                else:
                    for key in loss_alphas.keys():
                        cur_loss = loss_alphas[key]
                        for alpha_key in cur_loss.keys():
                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                if isinstance(input, list):
-                    input = move_input_to_device(input, self.device)
-                    for inp in input:
-                        cnt += inp.shape[0]
-                else:
-                    cnt += input.shape[0]
+                cnt += self.dataloader.batch_size
+                if cnt // multiply_factor >= 1:
+                    alpha_update_iter += 1
+                    cnt = 0

-                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                    for key in best_alphas.keys():
-                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                        logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                    absorb_input_scales, weight_scales = self._cal_scales(
                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                    )
                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                    loss_alphas = {} ##TODO check need to remove this one
-                if cnt >= auto_calib_iter:
+                if cnt >= calib_sample_num:
                    break

        best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
        for key in best_alphas.keys():
-            logger.info(f"final {key}:{best_alphas[key]}")
+            logger.info(f"Final alpha {key}:{best_alphas[key]}")
        self._qdq_model_unwrapper_for_auto()
        logger.info("auto tuning done")
        return best_alphas
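The core of the fix in this hunk: samples are now counted by self.dataloader.batch_size per batch instead of being inferred from input tensor shapes, per-alpha losses are accumulated each batch, and the running best alphas are refreshed once multiply_factor samples have been seen. A self-contained toy of that counting/update pattern; the dataloader, the loss, and the per-iteration cnt reset are simplified stand-ins, not the repository's implementation:

# Toy sketch of the batch_size-based counting and periodic best-alpha refresh.
import random

class ToyLoader:
    """Stand-in calibration dataloader exposing batch_size, like the test loaders below."""
    batch_size = 3

    def __iter__(self):
        for _ in range(4):
            yield [random.random() for _ in range(self.batch_size)]

def toy_loss(batch, alpha):
    # Stand-in for the accumulated per-alpha quantization loss.
    return sum(abs(x - alpha) for x in batch)

def tune_alpha(dataloader, alpha_space, calib_sample_num=32):
    multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num
    loss_per_alpha = {a: 0.0 for a in alpha_space}
    alpha_update_iter, cnt, best_alpha = 0, 0, 0.5
    for batch in dataloader:
        for a in alpha_space:
            loss_per_alpha[a] += toy_loss(batch, a)
        cnt += dataloader.batch_size          # count samples via batch_size, not tensor shapes
        if cnt // multiply_factor >= 1:       # enough new samples: refresh the running best alpha
            alpha_update_iter += 1
            best_alpha = min(loss_per_alpha, key=loss_per_alpha.get)
        if cnt >= calib_sample_num:
            break
    return best_alpha

print(tune_alpha(ToyLoader(), [0.3, 0.4, 0.5, 0.6, 0.7]))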
@@ -999,7 +995,7 @@ def transform(

        if alpha == "auto":
            self.alpha_per_layer = self._auto_tune_alpha_new(
-                input_maxes_abs, auto_calib_iter=32, **auto_alpha_args
+                input_maxes_abs, calib_sample_num=32, **auto_alpha_args
            ) ##save the alpha

        if alpha == "auto":
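Note that transform pins calib_sample_num=32 and forwards everything else through auto_alpha_args. For illustration only, a dictionary matching the defaults in the _auto_tune_alpha_new signature shown earlier in this diff; the keys are assumptions drawn from that signature, not a documented config schema:

# Illustrative auto_alpha_args matching the tuner's default keyword arguments.
auto_alpha_args = {
    "alpha_min": 0.3,
    "alpha_max": 0.7,
    "alpha_step": 0.05,
    "shared_criterion": "min",
}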
11 changes: 7 additions & 4 deletions test/algorithm/test_smooth_quant.py
Expand Up @@ -53,10 +53,11 @@ def __iter__(self):

class LLMCalibDataloader:
def __init__(self):
self.batch_size = 1
self.batch_size = 3

def __iter__(self):
yield torch.ones([1, 3], dtype=torch.long)
for i in range(4):
yield torch.ones([3, 3], dtype=torch.long)


class TestSqDepthwiseConv(unittest.TestCase):
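With batch_size raised to 3 and four batches yielded, this loader now supplies 12 calibration samples, exercising the batch_size-based counting above. A hedged usage sketch in the style of the surrounding tests; build_test_model() is a placeholder for however the tests construct their small transformer model and is not an API from this repository:

# Sketch only: drive the auto-alpha path with the updated loader.
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant

model = build_test_model()                          # placeholder, see note above
sq = TorchSmoothQuant(model, LLMCalibDataloader())
sq.transform(alpha="auto", calib_iter=-1, folding=False)  # auto path counts samples via loader.batch_size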
@@ -736,6 +737,7 @@ def test_sq_qkv(self):
        sq.transform(alpha=0.5, calib_iter=-1, folding=False)
        assert isinstance(sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper)

+class TestExample(unittest.TestCase):
    def test_sq_quant(self):
        from neural_compressor import PostTrainingQuantConfig, quantization

@@ -763,10 +765,11 @@ def forward(self, x):

        class CalibDataloader:
            def __init__(self):
-                self.batch_size = 1
+                self.batch_size = 3

            def __iter__(self):
-                yield input_ids
+                for i in range(4):
+                    yield input_ids

        def calib_func(model):
            for i in range(10):
