fix bug in smoothquant for auto alpha #1287

Merged · 4 commits · Sep 27, 2023
58 changes: 27 additions & 31 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0):
if len(output.shape) <= 2:
max_value = torch.max(torch.abs(output))
else:
max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values
output = output.reshape(output.shape[0], -1)
output_q = output_q.reshape(output_q.shape[0], -1)
max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
max_value = torch.clip(max_value, 1e-5)
output = output / max_value ##FIXME need copy not replace
output_q = output_q / max_value
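The fix above flattens both output and output_q per sample and keeps a trailing dimension on the per-sample max (unsqueeze(-1)), so the division that follows broadcasts row-wise; the replaced single-line version left the tensors un-flattened and the max without that trailing dimension. A hedged, self-contained sketch of the intended normalization (the helper name and shapes are illustrative, not from the PR):

```python
import torch

def normalize_pair(output: torch.Tensor, output_q: torch.Tensor):
    """Normalize fp32 and quantized outputs by the per-sample max of the fp32 output."""
    if len(output.shape) <= 2:
        max_value = torch.max(torch.abs(output))
    else:
        # Flatten each sample, take its max |value|, and keep a trailing dim so the
        # division below broadcasts per row instead of against the wrong axis.
        output = output.reshape(output.shape[0], -1)
        output_q = output_q.reshape(output_q.shape[0], -1)
        max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
    max_value = torch.clip(max_value, 1e-5)
    return output / max_value, output_q / max_value

# Hypothetical shapes: a batch of 2 samples, each with a 4x8 activation.
out_fp32 = torch.randn(2, 4, 8)
out_qdq = out_fp32 + 0.01 * torch.randn(2, 4, 8)
norm_fp32, norm_q = normalize_pair(out_fp32, out_qdq)
print(norm_fp32.shape, norm_q.shape)  # torch.Size([2, 32]) torch.Size([2, 32])
```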
@@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales):
weight_scale = self._reshape_scale_for_weight(layer, weight_scale)
layer.update_scale(input_scale, weight_scale) ##FIXME

def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
self._change_qdq_for_auto(enable=False)

forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output
@@ -793,15 +795,15 @@ def dict_to_list(dic):
return best_alpha

def _auto_tune_alpha_new(
self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
):
"""Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.

This function takes the quantization of former layers into consideration when qdq-ing one layer.
Also, it reduces memory usage at the cost of increased tuning time.
TODO may have compatibility issue when setting folding=True
:param input_maxes:
:param auto_calib_iter:
:param calib_sample_num:
:param alpha_min:
:param alpha_max:
:param alpha_step:
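For reference, the defaults alpha_min=0.3, alpha_max=0.7, alpha_step=0.05 span nine candidate alphas. A hedged sketch of how such a grid could be enumerated; the actual alpha_space construction lives outside this hunk, so the helper below is illustrative only:

```python
import numpy as np

def build_alpha_space(alpha_min=0.3, alpha_max=0.7, alpha_step=0.05):
    # Round to avoid float drift such as 0.4500000000000001 in the candidate list.
    return [round(a, 2) for a in np.arange(alpha_min, alpha_max + alpha_step / 2, alpha_step)]

print(build_alpha_space())  # [0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7]
```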
@@ -828,88 +830,82 @@ def _auto_tune_alpha_new(
self.absorb_to_layer, input_maxes, default_alpha, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {}
cnt = 0
multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter
alpha_update_iter = 0
# multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num

best_alphas = default_alpha
if not self.dataloader:
self._qdq_model_unwrapper_for_auto()
return best_alphas
try:
for input, label in self.dataloader:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
for key in self.absorb_to_layer.keys():
layer_names = self.absorb_to_layer[key]
for layer_name in layer_names:
best_alphas_per_module[layer_name] = best_alphas_per_module[key]

loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
if loss_alphas == {}:
loss_alphas = loss_tmp
else:
for key in loss_alphas.keys():
cur_loss = loss_alphas[key]
for alpha_key in cur_loss.keys():
cur_loss[alpha_key] += loss_tmp[key][alpha_key]
if isinstance(input, list):
input = move_input_to_device(input, self.device)
for inp in input:
cnt += inp.shape[0]
else:
cnt += input.shape[0]

if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
cnt += self.dataloader.batch_size
if cnt // multiply_factor >= 1:
alpha_update_iter += 1
cnt = 0
best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
absorb_input_scales, weight_scales = self._cal_scales(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {} ##TODO check need to remove this one
if cnt >= auto_calib_iter:
if cnt >= calib_sample_num:
break
except:
for input in self.dataloader:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
for key in self.absorb_to_layer.keys():
layer_names = self.absorb_to_layer[key]
for layer_name in layer_names:
best_alphas_per_module[layer_name] = best_alphas_per_module[key]

loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
if loss_alphas == {}:
loss_alphas = loss_tmp
else:
for key in loss_alphas.keys():
cur_loss = loss_alphas[key]
for alpha_key in cur_loss.keys():
cur_loss[alpha_key] += loss_tmp[key][alpha_key]
if isinstance(input, list):
input = move_input_to_device(input, self.device)
for inp in input:
cnt += inp.shape[0]
else:
cnt += input.shape[0]
cnt += self.dataloader.batch_size
if cnt // multiply_factor >= 1:
alpha_update_iter += 1
cnt = 0

if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
absorb_input_scales, weight_scales = self._cal_scales(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {} ##TODO check need to remove this one
if cnt >= auto_calib_iter:
if cnt >= calib_sample_num:
break

best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"final {key}:{best_alphas[key]}")
logger.info(f"Final alpha {key}:{best_alphas[key]}")
self._qdq_model_unwrapper_for_auto()
logger.info("auto tuning done")
return best_alphas
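The reworked loop above now advances by self.dataloader.batch_size and refreshes the best alphas whenever the counter crosses multiply_factor (a quarter of calib_sample_num), logging each refresh as an "Auto alpha update iter". A standalone, hedged sketch of that cadence under assumed numbers (batch_size=3, calib_sample_num=32); for readability it tracks the running total in a separate variable, whereas the diff reuses a single cnt:

```python
# Hedged sketch of the update cadence, not the PR code itself.
calib_sample_num = 32
batch_size = 3  # e.g. the LLMCalibDataloader used in the tests below
multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num  # 8

cnt = 0                 # samples seen since the last alpha refresh
seen = 0                # total calibration samples consumed
alpha_update_iter = 0
while seen < calib_sample_num:
    # ... run one batch and accumulate per-alpha losses here ...
    cnt += batch_size
    seen += batch_size
    if cnt // multiply_factor >= 1:
        alpha_update_iter += 1
        cnt = 0         # reset so the next refresh waits for another ~multiply_factor samples
        # ... pick the current best alphas and refresh the QDQ scales ...

print(alpha_update_iter)  # 3 refreshes before 33 samples (11 batches of 3) end the loop
```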
@@ -999,7 +995,7 @@ def transform(

if alpha == "auto":
self.alpha_per_layer = self._auto_tune_alpha_new(
input_maxes_abs, auto_calib_iter=32, **auto_alpha_args
input_maxes_abs, calib_sample_num=32, **auto_alpha_args
) ##save the alpha

if alpha == "auto":
32 changes: 18 additions & 14 deletions test/algorithm/test_smooth_quant.py
@@ -53,18 +53,19 @@ def __iter__(self):

class LLMCalibDataloader:
def __init__(self):
self.batch_size = 1
self.batch_size = 3

def __iter__(self):
yield torch.ones([1, 3], dtype=torch.long)
for i in range(4):
yield torch.ones([3, 3], dtype=torch.long)
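The test dataloader now advertises batch_size=3 and yields four [3, 3] batches because the auto-alpha loop counts samples via self.dataloader.batch_size rather than the tensor shapes. Any calibration dataloader used with alpha="auto" therefore needs a batch_size attribute that matches what __iter__ yields; a minimal hedged sketch (class name and sizes are illustrative):

```python
import torch

class ToyCalibDataloader:
    """Yields a few fixed-size batches and advertises their batch size."""

    def __init__(self, batch_size=3, num_batches=4):
        self.batch_size = batch_size  # read by the auto-alpha loop via self.dataloader.batch_size
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            # shape[0] matches self.batch_size so the sample count stays accurate
            yield torch.ones([self.batch_size, 3], dtype=torch.long)
```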


class TestSqDepthwiseConv(unittest.TestCase):
@classmethod
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
@@ -141,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
@@ -181,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
@@ -386,21 +387,21 @@ class TestSqListInput(unittest.TestCase):
def setUpClass(self):
class ListDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield [torch.rand((1, 3))]

class TupleDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield (torch.rand((1, 3)))

class ListTupleDataLoader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
input1 = torch.rand((1, 3))
@@ -499,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
@@ -535,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
@@ -736,6 +737,8 @@ def test_sq_qkv(self):
sq.transform(alpha=0.5, calib_iter=-1, folding=False)
assert isinstance(sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper)


class TestExample(unittest.TestCase):
def test_sq_quant(self):
from neural_compressor import PostTrainingQuantConfig, quantization

@@ -763,10 +766,11 @@ def forward(self, x):

class CalibDataloader:
def __init__(self):
self.batch_size = 1
self.batch_size = 3

def __iter__(self):
yield input_ids
for i in range(4):
yield input_ids

def calib_func(model):
for i in range(10):
@@ -935,7 +939,7 @@ class TestSqSkipOp(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 4))
@@ -992,7 +996,7 @@ class TestSqSkipOp_attn(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 4))