

Enhance ONNX SmoothQuant tuning structure (#1123)
Signed-off-by: yuwenzho <[email protected]>
yuwenzho authored Aug 2, 2023
1 parent 1f6b1ad commit f0d51c2
Showing 5 changed files with 242 additions and 72 deletions.
141 changes: 77 additions & 64 deletions neural_compressor/adaptor/onnxrt.py
@@ -160,7 +160,7 @@ def __init__(self, framework_specific_info):

# sq algo and args
self.sq = None
self.cur_sq_args = None
self.cur_sq_args = {}

def smooth_quant(self, model, dataloader, iterations, alpha=0.5, folding=True,
percentile=99.999, op_types=['MatMul', 'Gemm', 'Conv', 'FusedConv'],
@@ -186,29 +186,38 @@ def smooth_quant(self, model, dataloader, iterations, alpha=0.5, folding=True,
return self.smooth_quant_model

from .ox_utils.smooth_quant import ORTSmoothQuant
# TODO: remove quantize_config as it has no consumer
quantize_config = None
# pre-optimization -> sq

# set params to cur_sq_args
self.cur_sq_args['alpha'] = alpha
self.cur_sq_args['folding'] = folding
self.cur_sq_args['percentile'] = percentile
self.cur_sq_args['op_types'] = op_types
self.cur_sq_args['scales_per_op'] = scales_per_op
self.cur_sq_args['calib_iter'] = iterations

# pre-optimization
self._pre_optimize(model)

# assign the algo to the adaptor, so the adaptor can call it later when needed
self.sq = ORTSmoothQuant(self.pre_optimized_model, dataloader, self.reduce_range, self.backend)
self.smooth_quant_model = self.sq.transform(
alpha, folding, percentile, op_types, scales_per_op, iterations, quantize_config)
self.sq.record_max_info = record_max_info
self.smooth_quant_model = self.sq.transform(**self.cur_sq_args)
logger.info("Updated the pre-optimized model with smooth quant model.")
# TODO: double-check the smooth_quant_model and pre_optimized_model to make sure there are not two fp32 model replicas
self.pre_optimized_model = self.smooth_quant_model
return self.smooth_quant_model

def _need_smooth_quant(self, tune_cfg) -> bool:
# compare the alpha from tune_cfg and current alpha to decide whether re-smooth model or not
# TODO
return False

def _parse_sq_args(self, tune_cfg, cur_sq_args) -> Dict:
# parse the sq args according to the tune cfg and current sq args
# TODO
return {}

"""Check the model needs smooth quant or not."""
recipe_cfgs = tune_cfg.get('recipe_cfgs', None)
if recipe_cfgs and recipe_cfgs.get('smooth_quant', False) \
and recipe_cfgs['smooth_quant_args'].get('alpha', None):
# update alpha according to tune_cfg
self.cur_sq_args['alpha'] = \
tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha']
return True
else:
return False

@dump_elapsed_time("Pass quantize model")
def quantize(self, tune_cfg, model, data_loader, q_func=None):
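For illustration, a minimal sketch of a tuning config that would satisfy the _need_smooth_quant check above; all keys other than recipe_cfgs, smooth_quant, smooth_quant_args, and alpha are hypothetical and not taken from this commit.

# Hypothetical tune_cfg fragment; only the recipe_cfgs keys below are read by _need_smooth_quant.
tune_cfg = {
    'calib_iteration': 10,                     # illustrative value
    'recipe_cfgs': {
        'smooth_quant': True,
        'smooth_quant_args': {'alpha': 0.7},   # the tuner's current alpha candidate
    },
}
# _need_smooth_quant(tune_cfg) returns True and updates self.cur_sq_args['alpha'] to 0.7.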
@@ -224,17 +233,9 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
Returns:
(dict): quantized model
"""
# two steps to re-smooth the model if needed
if self._need_smooth_quant(tune_cfg):
# step1. recover the sq to original fp32 model
self.sq.recover()
new_sq_args = self._parse_sq_args(tune_cfg, self.cur_sq_args)
# step2. re-smooth the model with new alpha
model = self.smooth_quant(model=model, dataloader=data_loader, iterations=new_sq_args['iterations'],\
alpha=new_sq_args['alpha'], folding=new_sq_args['folding'], scales_per_op=new_sq_args['scales_per_op'])
"""
assert q_func is None, "quantization aware training is not supported on ONNXRUNTIME"
if self.smooth_quant_model is not None:
if self.smooth_quant_model is not None and model.is_smoothquant_model():
model = self.smooth_quant_model
elif self.pre_optimized_model is not None:
model = self.pre_optimized_model
@@ -277,47 +278,19 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
repr(e)))
tmp_model = model

# smooth quant the model if needed
if self._need_smooth_quant(tune_cfg) and not tmp_model.is_smoothquant_model():
self.sq.model = tmp_model
self.sq.record_max_info = False
tmp_model = self.sq.transform(**self.cur_sq_args)

iterations = tune_cfg.get('calib_iteration', 1)
calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
if not self.dynamic:
if isinstance(data_loader, BaseDataLoader):
batch_size = data_loader.batch_size
try:
for i in range(batch_size):
if calib_sampling_size % (batch_size - i) == 0:
calib_batch_size = batch_size - i
if i != 0: # pragma: no cover
logger.warning("Reset `calibration.dataloader.batch_size` field "
"to {}".format(calib_batch_size) +
" to make sure the sampling_size is "
"divisible exactly by batch size")
break
tmp_iterations = int(math.ceil(calib_sampling_size / calib_batch_size))
data_loader.batch(calib_batch_size)
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, tmp_iterations)
except Exception as e: # pragma: no cover
if 'Got invalid dimensions for input' in str(e):
logger.warning("Please set sampling_size to a multiple of {}".format(
str(e).partition('Expected: ')[2].partition('\n')[0]))
exit(0)
logger.warning(
"Fail to forward with batch size={}, set to {} now.".
format(batch_size, 1))
data_loader.batch(1)
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, calib_sampling_size)
else: # pragma: no cover
if hasattr(data_loader, 'batch_size') and \
calib_sampling_size % data_loader.batch_size != 0:
logger.warning(
"Please note that calibration sampling size {} " \
"isn't divisible exactly by batch size {}. " \
"So the real sampling size is {}.".
format(calib_sampling_size, data_loader.batch_size,
data_loader.batch_size * iterations))
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, iterations)
calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations)
quantize_params = self._get_quantize_params(tmp_model, data_loader, \
quantize_config, calib_iterations)
else:
quantize_params = None
self.quantize_params = quantize_params
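For readability, an annotated sketch of the re-smoothing path added in the hunk above; it mirrors the diff and assumes self.sq was created by an earlier smooth_quant() call that recorded the max info.

# Re-smooth only when the tuner asks for it and the working copy is still plain FP32.
if self._need_smooth_quant(tune_cfg) and not tmp_model.is_smoothquant_model():
    self.sq.model = tmp_model            # operate on the deep-copied working model
    self.sq.record_max_info = False      # apply the scales this time instead of only recording max info
    tmp_model = self.sq.transform(**self.cur_sq_args)  # cur_sq_args['alpha'] was refreshed from tune_cfg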
@@ -350,6 +323,45 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
self._dump_model_op_stats(tmp_model)
tmp_model.topological_sort()
return tmp_model

def _reset_calib_iter(self, data_loader, cfg_calib_sampling_size, cfg_calib_iter):
"""Check and reset calibration iterations according to calib_sampleing_size and dataloader batch_size."""
if isinstance(data_loader, BaseDataLoader):
batch_size = data_loader.batch_size
try:
for i in range(batch_size):
if cfg_calib_sampling_size % (batch_size - i) == 0:
calib_batch_size = batch_size - i
if i != 0: # pragma: no cover
logger.warning("Reset `calibration.dataloader.batch_size` field "
"to {}".format(calib_batch_size) +
" to make sure the sampling_size is "
"divisible exactly by batch size")
break
tmp_iterations = int(math.ceil(cfg_calib_sampling_size / calib_batch_size))
data_loader.batch(calib_batch_size)
calib_iterations = tmp_iterations
except Exception as e: # pragma: no cover
if 'Got invalid dimensions for input' in str(e):
logger.warning("Please set sampling_size to a multiple of {}".format(
str(e).partition('Expected: ')[2].partition('\n')[0]))
exit(0)
logger.warning(
"Fail to forward with batch size={}, set to {} now.".
format(batch_size, 1))
data_loader.batch(1)
calib_iterations = cfg_calib_sampling_size
else: # pragma: no cover
if hasattr(data_loader, 'batch_size') and \
cfg_calib_sampling_size % data_loader.batch_size != 0:
logger.warning(
"Please note that calibration sampling size {} " \
"isn't divisible exactly by batch size {}. " \
"So the real sampling size is {}.".
format(cfg_calib_sampling_size, data_loader.batch_size,
data_loader.batch_size * cfg_calib_iter))
calib_iterations = cfg_calib_iter
return calib_iterations

def _generate_qconfig(self, model, tune_cfg, quantize_params):
tune_cfg = copy.deepcopy(tune_cfg)
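A hedged worked example of the batch-size adjustment performed by _reset_calib_iter above; the numbers are illustrative, not taken from this commit.

import math

# Suppose calib_sampling_size = 100 and the dataloader uses batch_size = 8.
# The loop tries 8, 7, 6, 5, ... and stops at the first batch size that divides 100:
#   100 % 8 != 0, 100 % 7 != 0, 100 % 6 != 0, 100 % 5 == 0  ->  calib_batch_size = 5
calib_sampling_size, calib_batch_size = 100, 5
calib_iterations = int(math.ceil(calib_sampling_size / calib_batch_size))  # 20 iterations
# data_loader.batch(5) is then called, so 5 * 20 samples exactly cover the sampling size.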
@@ -512,10 +524,11 @@ def _get_quantize_params(self, model, data_loader, quantize_config, iterations):
model = ONNXModel(model)
black_nodes = [node for node in quantize_config if quantize_config[node]=='fp32']
white_nodes = [node for node in quantize_config if quantize_config[node]!='fp32']

augment = ONNXRTAugment(model, \
data_loader, self.quantizable_op_types, \
black_nodes=black_nodes, white_nodes=white_nodes, \
iterations=list(range(0, quantize_config['calib_iteration'])),
iterations=list(range(0, iterations)), \
backend=self.backend, reduce_range=self.reduce_range)
self.min_max = augment.dump_minmax(quantize_config)
quantize_params = augment.dump_calibration(quantize_config, min_max=self.min_max)
25 changes: 20 additions & 5 deletions neural_compressor/adaptor/ox_utils/smooth_quant.py
@@ -121,6 +121,7 @@ def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionP
self.tensors_to_node = {}
self.replace_input = []
self.ops_to_absorb = []
self.record_max_info = False
self._build_absorb_function()

def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm', 'Conv', 'MatMul', 'FusedConv'],
@@ -142,6 +143,7 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm'
An FP32 model with the same architecture as the original model but with different weights, which will
be beneficial to quantization
"""
self.clean()
if isinstance(alpha, float) and (alpha < 0 or alpha > 1):
logger.warning("alpha should be a float value in [0, 1] or 'auto' ")
if alpha < 0:
Expand All @@ -155,12 +157,16 @@ def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm'
if need_calibration:
self._dump_op_info(percentile, op_types, calib_iter, quantize_config)

if alpha == 'auto':
alpha = self._auto_tune_alpha(calib_iter, **auto_alpha_args)
if self.record_max_info:
return self.model

if alpha == 'auto':
alpha = self._auto_tune_alpha(calib_iter, **auto_alpha_args)

scales = self._get_smooth_scales(alpha)
self._insert_smooth_mul_op(scales)
self._adjust_weights(scales)

scales = self._get_smooth_scales(alpha)
self._insert_smooth_mul_op(scales)
self._adjust_weights(scales)
self.model.add_nodes(self.new_added_mul_nodes)
self.model.model.graph.value_info.extend(self.new_added_value_info)
self.model.add_initializers(self.new_init_tensors)
@@ -202,6 +208,15 @@ def recover(self):
self.new_added_mul_nodes = []
self.new_init_tensors = []
self.new_added_value_info = []
self.replace_input = []

def clean(self):
"""Clean data collected from calibration."""
self.tensor_scales_info = {}
self.new_added_mul_nodes = []
self.new_init_tensors = []
self.new_added_value_info = []
self.replace_input = []

def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter):
"""Check need calibration or not.
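A hedged sketch of how the adaptor is expected to drive ORTSmoothQuant after this change (two-pass use of record_max_info); the model and dataloader variables are placeholders.

from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant

# Pass 1 (during smooth_quant()): collect calibration max info only; the graph is left untouched.
sq = ORTSmoothQuant(fp32_model, calib_dataloader)   # placeholder model / dataloader
sq.record_max_info = True
model = sq.transform(alpha=0.5, calib_iter=5)       # returns the wrapped model without modification

# Pass 2 (during quantize(), once the tuner has picked an alpha): actually apply the smoothing.
sq.record_max_info = False
smoothed = sq.transform(alpha=0.7, calib_iter=5)    # transform() now calls clean() first, discarding stale scales/nodes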
4 changes: 2 additions & 2 deletions neural_compressor/config.py
@@ -846,8 +846,8 @@ def smooth_quant_args(val=None):
smooth_quant_args = {"alpha": numpy.arange(0.1, 0.5, 0.05).tolist()}
"""
if isinstance(v, str):
assert v == "auto", "the alpha of sq only supports float and 'auto'"
elif isinstance(v, float) or isinstance(v, int):
assert v == "auto", "the alpha of sq only supports float, list and 'auto'"
elif isinstance(v, float) or isinstance(v, int) or isinstance(v, list):
continue
else:
logger.warning("Ignore the alpha as it's not a list, int or float.")
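To show what the relaxed alpha check enables at the user level, a hedged sketch using the library's post-training quantization config; the import path and surrounding usage follow neural_compressor's documented API but are not part of this diff.

import numpy as np
from neural_compressor import PostTrainingQuantConfig

# A list of candidate alphas is now accepted in addition to a float or 'auto',
# so the tuner can try each value in turn.
config = PostTrainingQuantConfig(
    recipes={
        "smooth_quant": True,
        "smooth_quant_args": {"alpha": np.arange(0.1, 0.5, 0.05).tolist()},
    }
)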
11 changes: 11 additions & 0 deletions neural_compressor/model/onnx_model.py
@@ -792,3 +792,14 @@ def match_parent_path(
current_node = matched_parent

return matched_parents

def is_smoothquant_model(self):
"""Check the model is smooth quantized or not.
Returns:
bool: the model is smooth quantized or not.
"""
for init in self.model.graph.initializer:
if "_smooth_scale" in init.name:
return True
return False
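The same marker check can be reproduced outside the ONNXModel wrapper; a hedged standalone sketch in which the file path is a placeholder.

import onnx

def looks_smooth_quantized(model_path):
    # Return True if any initializer name contains '_smooth_scale',
    # the marker that is_smoothquant_model() looks for.
    model = onnx.load(model_path)   # model_path is a placeholder for an exported ONNX file
    return any("_smooth_scale" in init.name for init in model.graph.initializer)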