From 9ff7f01c3ca9f5aba0aff01260d58ce3007a8f4c Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Fri, 28 Jul 2023 14:18:52 +0800 Subject: [PATCH] support TEQ layerwise config. (#1120) * support TEQ layerwise config. * fix bug of folding=false. * fix ut. * fix ut. * fix ut. * fix comments. * weight config can exclude specific layer. * fix coverage issue. --- neural_compressor/adaptor/pytorch.py | 88 +++++++++--- neural_compressor/adaptor/pytorch_cpu.yaml | 4 +- .../adaptor/torch_utils/model_wrapper.py | 12 +- neural_compressor/adaptor/torch_utils/teq.py | 131 ++++++------------ .../adaptor/torch_utils/weight_only.py | 5 +- .../test_weight_only_quantization.py | 34 ++--- 6 files changed, 138 insertions(+), 136 deletions(-) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 838dbe0bfaa..5e29dc1b67a 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -4589,39 +4589,89 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func): logger.debug("quantizing with the TEQ algorithm") from .torch_utils.weight_only import teq_quantize # get example inputs if not provided. - if self.example_inputs is None: + if self.example_inputs is None: # pragma: no cover if dataloader is None: assert False, "Please provide dataloader or example_inputs for TEQ algorithm." try: - for idx, (input, label) in enumerate(dataloader): - self.example_inputs = input + for idx, (x, label) in enumerate(dataloader): + self.example_inputs = x.to(model.device) break except: - for idx, input in enumerate(dataloader): - self.example_inputs = input + for idx, x in enumerate(dataloader): + self.example_inputs = x.to(model.device) break - if 'teq_args' in self.recipes: - wbits = self.recipes.get('wbits', 4) - group_size = self.recipes.get('group_size', 128) - sym = self.recipes.get('scheme', False) - folding = self.recipes.get('folding', True) + folding = True + if 'teq_args' in self.recipes: # pragma: no cover + folding = self.recipes['teq_args'].get('folding', True) + + supported_layers = ['Linear'] + if folding: # pragma: no cover + from .torch_utils.smooth_quant import GraphTrace + tg = GraphTrace() + absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers) + if absorb_to_layer is None or absorb_to_layer == {}: + logger.warning('No absorb layer is detected, skip TEQ algorithm') + return model + else: # pragma: no cover + absorb_to_layer = {} + for name, module in model.named_modules(): + for op_type in supported_layers: + if op_type == str(module.__class__.__name__): + absorb_to_layer[name] = [name] - weight_config = { - 'wbits': wbits, - 'group_size': group_size, - 'sym': sym, - 'folding': folding - } - quantizer = teq_quantize( + # got flipped dict from absorb_to_layer dict + flipped_dict = {} + for k, v in absorb_to_layer.items(): + for m in v: + flipped_dict[m] = {'absorb_layer': k} + + # check tune_cfg to skip layers without TEQ config + weight_config = {} + skipped_op_name_set = set() + for key, config in tune_cfg['op'].items(): + op_name, op_type = key + if config['weight']['dtype'] == 'fp32': # pragma: no cover + if op_name in flipped_dict: + absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer']) + continue + else: + weight_config[op_name] = {} + weight_config[op_name]['bits'] = config['weight']['bits'] + weight_config[op_name]['group_size'] = config['weight']['group_size'] + weight_config[op_name]['scheme'] = config['weight']['scheme'] + if op_name in flipped_dict: + algorithm = config['weight']['algorithm'] + if algorithm != 'TEQ': + absorb_to_layer.pop(weight_config[op_name]['absorb_layer']) + else: + skipped_op_name_set.add(op_name) + if skipped_op_name_set: # pragma: no cover + logger.info("{} is skipped by TEQ algorithm".format(skipped_op_name_set)) + + # collect TEQ config from tune_cfg for quantization. + if len(absorb_to_layer) == 0: # pragma: no cover + logger.warning('No absorb layer needs TEQ algorithim, skip it') + else: # pragma: no cover + logger.debug("**absorb layer**: **absorbed layers**") + for k, v in absorb_to_layer.items(): + logger.debug(f"{k}: {v}") + + logger.info("Absorbed layers with the same absorb layer use the same config") + + extra_config = {"folding": folding} + + model = teq_quantize( model, weight_config, + absorb_to_layer, + extra_config, dataloader, example_inputs=self.example_inputs, calib_func=calib_func ) - return quantizer.model - + return model + def awq_quantize(self, model, tune_cfg, dataloader, calib_func): logger.debug("quantizing with the AWQ algorithm") from .torch_utils.weight_only import awq_quantize diff --git a/neural_compressor/adaptor/pytorch_cpu.yaml b/neural_compressor/adaptor/pytorch_cpu.yaml index f4c0416ec8a..b94bc71e284 100644 --- a/neural_compressor/adaptor/pytorch_cpu.yaml +++ b/neural_compressor/adaptor/pytorch_cpu.yaml @@ -267,7 +267,7 @@ # group_size=-1 means per-channel, others means per-group 'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32 'scheme': ['sym', 'asym'], # sym, no ZP - 'algorithm': ['RTN', 'AWQ', 'GPTQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order + 'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order }, 'activation': { 'dtype': ['fp32'], @@ -445,4 +445,4 @@ 'dynamic': *cap_dynamic_s8_1_6, 'quant_aware': *cap_s8_1_6 } - uint8: *cap_s8_1_6 \ No newline at end of file + uint8: *cap_s8_1_6 diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index afbcdbdc50c..24649667ae6 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -364,7 +364,7 @@ class FakeAffineTensorQuantFunction(Function): """ @staticmethod - def forward(ctx, inputs, num_bits=4, group_size=1024): + def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"): """ As it will be only applied on activation with per tensor granularity, broadcast is not needed. @@ -379,7 +379,7 @@ def forward(ctx, inputs, num_bits=4, group_size=1024): Returns: outputs: A Tensor of type output_dtype """ - return quant_weight(inputs, num_bits, group_size) + return quant_weight(inputs, num_bits, group_size, scheme) @staticmethod def backward(ctx, grad_outputs): @@ -391,7 +391,7 @@ def backward(ctx, grad_outputs): Returns: grad_inputs: A tensor of gradient """ - return grad_outputs, None, None + return grad_outputs, None, None, None class TEQLinearFakeQuant(torch.nn.Module): @@ -399,7 +399,7 @@ class TEQLinearFakeQuant(torch.nn.Module): wrapper quantization linear """ - def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1): + def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1, scheme="asym"): """ A forward hook to linear module :param orig_layer: the original module @@ -413,6 +413,7 @@ def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1): self.num_bits = num_bits self.group_size = group_size + self.scheme = scheme def forward(self, x): alpha = torch.clip(self.alpha, 1e-5) @@ -421,7 +422,8 @@ def forward(self, x): x = x / alpha.view(shape) weight = self.orig_layer.weight weight = weight * alpha.unsqueeze(dim=0) - weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size) + weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, + self.group_size, self.scheme) return F.linear(x, weight_q, self.orig_layer.bias) diff --git a/neural_compressor/adaptor/torch_utils/teq.py b/neural_compressor/adaptor/torch_utils/teq.py index c0294b16e0a..8d084c7d50c 100644 --- a/neural_compressor/adaptor/torch_utils/teq.py +++ b/neural_compressor/adaptor/torch_utils/teq.py @@ -41,6 +41,8 @@ def __init__( self, model, weight_config={}, + absorb_to_layer={}, + extra_config={}, example_inputs=None ): """ @@ -48,17 +50,14 @@ def __init__( :param weight_config (dict, optional): contains all info required by GPTQ. Defaults to {}. :param example_inputs: inputs for trace """ - self.model = model - self.num_bits = weight_config.get('wbits', 4) - self.group_size = weight_config.get('group_size', -1) - self.scheme = weight_config.get('sym', False) - self.folding = weight_config.get('folding', True) + self.weight_config = weight_config + self.folding = extra_config.get('folding', True) self.example_inputs = example_inputs self.device, self.dtype = self._get_device() self.model.eval() - self.trained_alphas = {} + self.absorb_to_layer = absorb_to_layer def _get_device(self): """ @@ -68,58 +67,18 @@ def _get_device(self): for _, p in self.model.named_parameters(): return p.data.device, p.data.dtype - def add_tuning_scale(self, op_types=['Linear'], excluded_name="lm_head", - excluded_key=None, sqrt_w_init=False): + def add_tuning_scale(self, sqrt_w_init=False): """ The main entry of smooth quant to the paper for more details - :param op_types: The op typed to be smooth quantized - :param excluded_name: exclude layer - :param excluded_key: exclude key :param sqrt_w_init: use sqrt weight to init """ - if self.folding: - self.insert_mul = False - else: - self.insert_mul = True - - with torch.no_grad(): - if self.insert_mul: - self.absorb_to_layer = self._get_all_layer_names() # TODO: only support linear now. - else: - self.absorb_to_layer, no_absorb_layers = self._trace( - op_types) ##TODO we need to insert mul layer for no_absorb_layers later - if self.absorb_to_layer == None and no_absorb_layers == None: # pragma: no cover - logger.warning("sorry, could not trace the model, smooth quant is skipped") - logger.warning("if you are using huggingface model," - "you could set torchscript to True " - "when loading the model or set the return_dict to False") - elif self.absorb_to_layer == {}: # pragma: no cover - logger.warning("could not find any layer to be absorbed") - else: - to_absorb_cnt = 0 - for key, item in self.absorb_to_layer.items(): - to_absorb_cnt += len(item) - - logger.info( - f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " - f"layers could be absorbed in smooth quant") # freeze model. for n, p in self.model.named_parameters(): p.requires_grad = False - for key, item in self.absorb_to_layer.items(): - if len(item) == 1 and excluded_name in item[0]: - excluded_key = key - break - - if excluded_key != None: - self.absorb_to_layer.pop(excluded_key) ## remove - for layer_norm in self.absorb_to_layer: - if excluded_name in self.absorb_to_layer[layer_norm][0]: # pragma: no cover - continue layer_0_name = self.absorb_to_layer[layer_norm][0] @@ -143,32 +102,33 @@ def add_tuning_scale(self, op_types=['Linear'], excluded_name="lm_head", self.trained_alphas[layer_norm] = alpha for layer_name in self.absorb_to_layer[layer_norm]: + if self.weight_config.get(layer_name) is None: # pragma: no cover + logger.info(f"layer {layer_name} not in weight config, skip.") + continue + num_bits = self.weight_config[layer_name]["bits"] + group_size = self.weight_config[layer_name]["group_size"] + scheme = self.weight_config[layer_name]["scheme"] + module = get_module(self.model, layer_name) wrapper_module = TEQLinearFakeQuant(orig_layer=module, alpha=alpha, - num_bits=self.num_bits, group_size=self.group_size) + num_bits=num_bits, group_size=group_size, scheme=scheme) set_module(self.model, layer_name, wrapper_module) for n, m in self.model.named_modules(): - if isinstance(m, torch.nn.Linear) and excluded_name not in n and "orig_layer" not in n: + if isinstance(m, torch.nn.Linear) and "orig_layer" not in n: + if self.weight_config.get(n) is None: # pragma: no cover + logger.info(f"out of absorbed layer {n} not in weight config, skip.") + continue + num_bits = self.weight_config[layer_name]["bits"] + group_size = self.weight_config[layer_name]["group_size"] + scheme = self.weight_config[layer_name]["scheme"] + alpha = torch.nn.Parameter(torch.ones(m.weight.shape[1], device=self.device)) alpha.requires_grad_(False) wrapper_module = TEQLinearFakeQuant(orig_layer=m, alpha=alpha, - num_bits=self.num_bits, group_size=self.group_size) + num_bits=num_bits, group_size=group_size, scheme=scheme) set_module(self.model, n, wrapper_module) - def _get_all_layer_names(self, op_types=['Linear']): - """ - Try the model to find the layers which can be smooth quantized. - :param op_types: The op types to be smooth quantized - :return: - """ - self_absorb_layer = {} - for name, module in self.model.named_modules(): - for op_type in op_types: - if op_type == str(module.__class__.__name__): - self_absorb_layer[name] = [name] - return self_absorb_layer - @torch.no_grad() def _absorb_scales(self, layer, scale, layer_name=""): """ @@ -178,12 +138,13 @@ def _absorb_scales(self, layer, scale, layer_name=""): :param layer_name: The layer name """ # for insert mul - if self.insert_mul: # pragma: no cover + if not self.folding: # pragma: no cover if isinstance(layer, TEQMulLinear): set_module(self.model, layer_name, layer.sq_linear) ##recover else: new_module = TEQMulLinear(layer, scale) set_module(self.model, layer_name, new_module) + self.weight_config[layer_name + ".sq_linear"] = self.weight_config[layer_name] return if isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.GroupNorm) or \ @@ -275,24 +236,12 @@ def transform(self): layer_module = get_module(self.model, layer_name) self._scale_layer_weight(layer_module, weight_scale) - # for insert_mul = False + # for Folding = True for n, m in self.model.named_modules(): if isinstance(m, TEQLinearFakeQuant): set_module(self.model, n, m.orig_layer) - def _trace(self, op_types): - """ - Try the model to find the layers which can be smooth quantized. - :param op_types: The op types to be smooth quantized - :return: - absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized - no_absorb_layers: A list saving the layers which could not find the absorb layer - """ - tg = GraphTrace() - absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.model, self.example_inputs, op_types) - return absorb_to_layer, no_absorb_layers - - def train(self, dataloader, train_steps=100, lr=1e-3, warmup_ratio=0.05, + def train(self, dataloader, train_steps=1000, lr=1e-3, warmup_ratio=0.05, gradient_accumulation_steps=1, logging_steps=10, betas=[0.9, 0.9], weight_decay=0, lr_scheduler_type="linear"): """ @@ -334,7 +283,7 @@ def train(self, dataloader, train_steps=100, lr=1e-3, warmup_ratio=0.05, optimizer.zero_grad() lr_scheduler.step() - if global_steps == train_steps: + if global_steps >= train_steps: # pragma: no cover break logger.info("finish training") @@ -342,24 +291,22 @@ def train(self, dataloader, train_steps=100, lr=1e-3, warmup_ratio=0.05, return None @torch.no_grad() - def quantize(self, scheme=None, quant_lm_head=False): + def quantize(self): """ quantization """ - if scheme is None: - scheme = self.scheme for n, m in self.model.named_modules(): - if quant_lm_head: - if isinstance(m, torch.nn.Linear): - m.weight.data.copy_( - quant_weight(m.weight, num_bits=self.num_bits, - group_size=self.group_size, scheme=scheme)) - else: - if isinstance(m, torch.nn.Linear) and "lm_head" not in n: - m.weight.data.copy_( - quant_weight(m.weight, num_bits=self.num_bits, - group_size=self.group_size, scheme=scheme)) + if self.weight_config.get(n) is None: # pragma: no cover + logger.info(f"quantize layer {n} not in weight config, skip.") + continue + num_bits = self.weight_config[n]["bits"] + group_size = self.weight_config[n]["group_size"] + scheme = self.weight_config[n]["scheme"] + if isinstance(m, torch.nn.Linear): # pragma: no cover + m.weight.data.copy_( + quant_weight(m.weight, num_bits=num_bits, + group_size=group_size, scheme=scheme)) def save(self, save_scale_file="", save_state_dict_file=""): """ diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index 3c4069f83e9..01d44092403 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -661,7 +661,8 @@ def forward(self, *args, **kwargs): logger.info("AWQ quantization is done.") return model -def teq_quantize(model, weight_config={}, dataloader= None, calib_func=None, example_inputs=None): +def teq_quantize(model, weight_config={}, absorb_to_layer={}, extra_config={}, + dataloader= None, calib_func=None, example_inputs=None): """Run weight-only quantization with """ assert isinstance(model, torch.nn.Module), "only support torch module" logger.info("TEQ quantizing start.") @@ -678,7 +679,7 @@ def teq_quantize(model, weight_config={}, dataloader= None, calib_func=None, exa break from .teq import TEQuantizer - teq_quantizer = TEQuantizer(model, weight_config, example_inputs) + teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, extra_config, example_inputs) # 1. wrapper tuning scale to model teq_quantizer.add_tuning_scale() diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py index cfe264a86dc..2be9b9b32f9 100644 --- a/test/quantization/test_weight_only_quantization.py +++ b/test/quantization/test_weight_only_quantization.py @@ -186,26 +186,28 @@ def train_func(self): def test_teq(self): dataloader = self.generate_random_corpus() model = copy.deepcopy(self.gptj) + weight_config = { - 'wbits': 4, - 'group_size': 128, - 'sym': True, - 'folding': True + # 'op_name': (bit, group_size, sheme) + 'transformer.h.0.mlp.fc_in': { + 'bits': 8, + 'group_size': -1, + 'scheme': 'sym' + }, + 'transformer.h.0.mlp.fc_out': { + 'bits': 4, + 'group_size': 32, + 'scheme': 'asym' + }, } + absorb_dict = { + 'transformer.h.0.mlp.fc_in': ['transformer.h.0.mlp.fc_out'] + } + extra_config = {'folding': True} - model = teq_quantize(model, weight_config=weight_config, dataloader=dataloader) - self.assertTrue(isinstance(model, torch.nn.Module)) - - del model - model = copy.deepcopy(self.gptj) - weight_config = { - 'wbits': 4, - 'group_size': 128, - 'sym': True, - 'folding': False - } - model = teq_quantize(model, weight_config=weight_config, dataloader=dataloader) + model = teq_quantize(model, weight_config=weight_config, absorb_to_layer=absorb_dict, + extra_config=extra_config, dataloader=dataloader) self.assertTrue(isinstance(model, torch.nn.Module)) if __name__ == "__main__":