From 7ba9fdcb24a8ea1c1efc27844f39d0c128f83517 Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Mon, 19 Aug 2024 14:50:50 +0800
Subject: [PATCH] support gptq `true_sequential` and `quant_lm_head` (#1977)

Signed-off-by: Kaihui-intel
---
 .azure-pipelines/scripts/ut/run_itrex.sh | 3 +-
 docs/source/3x/PT_WeightOnlyQuant.md | 5 +-
 .../torch/algorithms/weight_only/gptq.py | 422 ++++++++++++++----
 .../torch/algorithms/weight_only/rtn.py | 2 +
 .../torch/quantization/algorithm_entry.py | 2 +
 .../torch/quantization/config.py | 9 +-
 .../quantization/weight_only/test_gptq.py | 73 ++-
 .../quantization/weight_only/test_rtn.py | 13 +-
 8 files changed, 430 insertions(+), 99 deletions(-)

diff --git a/.azure-pipelines/scripts/ut/run_itrex.sh b/.azure-pipelines/scripts/ut/run_itrex.sh
index 2bbbf958398..5adaf86579b 100644
--- a/.azure-pipelines/scripts/ut/run_itrex.sh
+++ b/.azure-pipelines/scripts/ut/run_itrex.sh
@@ -18,7 +18,8 @@ bash /intel-extension-for-transformers/.github/workflows/script/install_binary.s
 sed -i '/neural-compressor.git/d' /intel-extension-for-transformers/tests/requirements.txt
 pip install -r /intel-extension-for-transformers/tests/requirements.txt
 # workaround
-pip install onnx==1.15.0
+pip install onnx==1.16.0
+pip install onnxruntime==1.18.0
 echo "pip list itrex ut deps..."
 pip list
 LOG_DIR=/neural-compressor/log_dir
diff --git a/docs/source/3x/PT_WeightOnlyQuant.md b/docs/source/3x/PT_WeightOnlyQuant.md
index 1b7a6e760ff..727b791a8f4 100644
--- a/docs/source/3x/PT_WeightOnlyQuant.md
+++ b/docs/source/3x/PT_WeightOnlyQuant.md
@@ -111,9 +111,10 @@ model = convert(model)
 | model_path (str) | Model path that is used to load state_dict per layer | |
 | use_double_quant (bool) | Enables double quantization | False |
 | act_order (bool) | Whether to sort Hessian's diagonal values to rearrange channel-wise quantization order | False |
-| percdamp (float) | Percentage of Hessian's diagonal values' average, which will be added to Hessian's diagonal to increase numerical stability | 0.01. |
+| percdamp (float) | Percentage of Hessian's diagonal values' average, which will be added to Hessian's diagonal to increase numerical stability | 0.01 |
 | block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 |
-| static_groups (bool) | Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements. | False. |
+| static_groups (bool) | Whether to calculate group wise quantization parameters in advance. This option mitigates actorder's extra computational requirements. | False |
+| true_sequential (bool) | Whether to quantize layers within a transformer block in their original order. This can lead to higher accuracy but a slower overall quantization process. | False |
 > **Note:** `model_path` is only used when use_layer_wise=True. `layer-wise` is stay-tuned.
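For readers of the documentation table above, the sketch below shows how the two new options are meant to be used together with the existing `prepare`/`convert` flow. It is an illustrative example, not part of the patch: the import path follows the 3.x API already shown in this document, while the checkpoint name and the calibration input simply mirror the tests added at the end of this patch.

```python
# Minimal usage sketch (illustrative, not part of the patch): quantize a tiny GPT-J
# checkpoint with the new `true_sequential` and `quant_lm_head` GPTQ options.
import torch
from transformers import AutoModelForCausalLM

from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
example_inputs = torch.ones([1, 10], dtype=torch.long)  # assumed calibration sample

quant_config = GPTQConfig(
    true_sequential=True,  # quantize layers inside each transformer block in their original order
    quant_lm_head=True,  # also quantize the lm_head layer found after the transformer blocks
)

model = prepare(model, quant_config)
model(example_inputs)  # calibration pass; GPTQ forward hooks collect per-layer statistics
model = convert(model)  # quantized linears are replaced with INCWeightOnlyLinear modules
```

With `quant_lm_head=True`, GPTQ additionally quantizes the module recorded as `transformers_post` (typically `lm_head`), so the converted model contains one more `INCWeightOnlyLinear` module than with the default setting; the new `test_quant_lm_head` test below asserts the resulting module count.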
``` python diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index fa50b64e86b..1dbd7511663 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -81,6 +81,7 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn "transformers": {}, Dict# TODO } """ + find_transformers = False if type(module).__name__ == "MixFormerSequentialForCausalLM": # pragma: no cover gptq_related_blocks = { "embeddings": {}, @@ -110,12 +111,19 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn } for n, m in module.named_modules(): if type(m) in module_types: + # find the block gptq_related_blocks["transformers_name"] = n gptq_related_blocks["transformers"] = m - return gptq_related_blocks + find_transformers = True + # return gptq_related_blocks + elif is_leaf(m) and not find_transformers: + gptq_related_blocks["embeddings"][n] = m + elif n.find(gptq_related_blocks["transformers_name"]) == -1 and find_transformers: + # no longer belong to transformers + gptq_related_blocks["transformers_post"]["name"] = n + gptq_related_blocks["transformers_post"]["layer"] = m else: - if is_leaf(m): - gptq_related_blocks["embeddings"][n] = m + continue return gptq_related_blocks @@ -178,6 +186,7 @@ def __init__( device=None, use_layer_wise=False, model_path="", + quant_lm_head=False, dataloader=None, *args, **kwargs, @@ -204,6 +213,7 @@ def __init__( dataloader: an iterable containing calibration datasets, contains (inputs, targets) use_layer_wise (bool): Enables quantize model per layer. Defaults to False. model_path (str): Model path that is used to load state_dict per layer. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers. Defaults to False. device (str): cpu or cuda. """ # model @@ -224,6 +234,8 @@ def __init__( self.sym_default = False self.act_order_default = False self.static_groups_default = False + self.true_sequential_default = False + self.quant_lm_head = quant_lm_head self.perchannel_default = True self.mse_default = False self.use_double_quant_default = False @@ -293,6 +305,9 @@ def check_layer_config(self): self.weight_config[layer_name]["sym"] = config.get("sym", self.sym_default) self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default) self.weight_config[layer_name]["static_groups"] = config.get("static_groups", self.static_groups_default) + self.weight_config[layer_name]["true_sequential"] = config.get( + "true_sequential", self.true_sequential_default + ) self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default) self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default) self.weight_config[layer_name]["use_double_quant"] = config.get( @@ -473,6 +488,46 @@ def update_blockwise_hidden_states(self, outs): else: self.cache_positional_arguments[0] = outs[:] + def find_true_sequential_config(self): + """Find true sequential config. + + Returns: + bool: True or False. + """ + for layer_name in self.weight_config: + if self.weight_config[layer_name].get("true_sequential", None) is not None: + return self.weight_config[layer_name]["true_sequential"] + return False + + def analyze_true_sequential(self, module, inputs=None): + """To obtain the depth of each linear layers in this block. + + Args: + module (nn.module): block. + inputs (optional): Defaults to None. 
+ + Returns: + list: layers grouping into sequentials. + """ + # to obtain the depth of each linear layers in this block + # obtain all linear layers' names + layers = find_layers(module) + layers = list(layers) + # group layers into sequentials + # case 1: query, key and value are calculated from one matrix, bloom, etc.. + if "q" in layers[0].lower() and "k" in layers[0].lower(): + qkv_layers = [layers[0]] + post_qkv_layers = layers[1:] + else: + # case 2: qkv are calculated separately. + qkv_layers = layers[0:3] + post_qkv_layers = layers[3:] + layers.clear() + layers.append(qkv_layers) + for layer in post_qkv_layers: + layers.append([layer]) + return layers + @torch.no_grad() def execute_quantization(self, means=None, stds=None): """Run quantization.""" @@ -482,6 +537,11 @@ def execute_quantization(self, means=None, stds=None): # Step2: run gptq quantization in a transformer block-wise manner. gptq_config = {} + + self.true_sequential = self.find_true_sequential_config() + # automatically get true_sequential + true_sequential_map = self.analyze_true_sequential(self.gptq_related_blocks["transformers"][0]) + logger.info(f"Sequential Name: {true_sequential_map}") tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") @@ -493,74 +553,273 @@ def execute_quantization(self, means=None, stds=None): # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} + # add true sequential options + if self.true_sequential is not None and self.true_sequential: + sequentials = true_sequential_map + else: + sequentials = [list(sub_layers.keys())] + # start to process every layers in a sequential + for sequential in sequentials: + logger.info(f"Current quantization sequential: {sequential}") + sub_layers_to_quant = {} + sequential_layers = {n: sub_layers[n] for n in sequential} + for layer_name, layer_obj in sequential_layers.items(): + # filter sub_layers with included layer_names in self.weight_config + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + # if self.weight_config.get(full_layer_name, None) == None: + if self.get_layer_config(full_layer_name) is None: + logger.warning( + f"{full_layer_name} can be quantized " + "but excluded from quantization configs." + ) + else: + sub_layers_to_quant[layer_name] = layer_obj + del sequential_layers + sequential_layers = sub_layers_to_quant + # Step 2.2: Initialize GPTQ quantizers for collected layers. 
+ gptq_for_this_block = {} + # initialize gptq quantizer for every layer in a transformer block + for layer_name in sequential_layers: + # weight_config_this_layer = self.weight_config.get( + # self.get_full_layer_name(layer_name, block_idx), None + # ) + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + weight_config_this_layer = self.get_layer_config(full_layer_name) + if self.use_layer_wise: # pragma: no cover + from neural_compressor.torch.algorithms.layer_wise import load_value + + W = load_value(self.model, full_layer_name + ".weight", self.model_path) + else: + W = sequential_layers[layer_name].weight.data.clone() + + gptq_for_this_block[layer_name] = GPTQ(sequential_layers[layer_name], W, self.device) + # gptq_for_this_block[layer_name].quantizer = Quantizer() + gptq_for_this_block[layer_name].quantizer.configure(weight_config_this_layer) + + # Step 2.3: modify forward functions to hook inputs data (used in gptq execution) + def add_batch(_name): + def tmp(_, inp, out): + gptq_for_this_block[_name].add_batch(inp[0].data, out.data) # noqa: F821 + + return tmp + + handles = [] # register handles which add inputs and outputs to gptq object + for layer_name in sequential_layers: + handles.append(sequential_layers[layer_name].register_forward_hook(add_batch(layer_name))) + batch_num = self.cache_key_arguments.pop("batch_num") + for j in range(batch_num): + cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) + cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) + accelerator.mark_step() + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) + self.cache_key_arguments["batch_num"] = batch_num + for h in handles: + h.remove() + # Step 2.4: everything is prepared, so start quantization! + for layer_name in sequential_layers: + # weight_config_this_layer = self.weight_config.get( + # self.get_full_layer_name(layer_name, block_idx), None + # ) + weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) + logger.info(f"Quantizing layer {layer_name}") + if self.use_layer_wise: # pragma: no cover + from neural_compressor.torch.algorithms.layer_wise import load_value + + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = load_value(self.model, full_layer_name + ".weight", self.model_path) + else: + W = sequential_layers[layer_name].weight.data.clone() + accelerator.mark_step() + if "hpu" in self.device: + W = W.to("cpu") + scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + W, + blocksize=weight_config_this_layer["block_size"], + percdamp=weight_config_this_layer["percdamp"], + groupsize=weight_config_this_layer["group_size"], + act_order=weight_config_this_layer["act_order"], + static_groups=weight_config_this_layer["static_groups"], + ) + if self.use_layer_wise: # pragma: no cover + from neural_compressor.torch.algorithms.layer_wise import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + + sub_layer = sequential_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in sub_layer.named_parameters(): + param_name = full_layer_name + "." 
+ n + if n == "weight": + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, self.model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") + clean_module_weight(sub_layer) + del Q + gc.collect() + else: + sequential_layers[layer_name].weight.data = Q + gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} + if not weight_config_this_layer["sym"]: + gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp + if weight_config_this_layer["act_order"]: # save perm for restoring the weights + gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] = gptq_for_this_block[ + layer_name + ].perm + gptq_for_this_block[layer_name].free() + + # Step 2.5: replace output data with quantized weights + outs = [] + batch_num = self.cache_key_arguments.pop("batch_num") + for j in range(batch_num): + cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) + cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) + outs.append(out) + self.cache_key_arguments["batch_num"] = batch_num + if self.use_layer_wise: # pragma: no cover + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + # Step 2.6: export to compressed model + for layer_name in sequential_layers: + weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) + gptq_scale = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["scale"] + if not weight_config_this_layer["sym"]: + gptq_zp = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] + else: + gptq_zp = None + if weight_config_this_layer["act_order"]: # save perm for restoring the weights + gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] + else: + gptq_perm = None + if self.use_layer_wise: + state_dict = torch.load( + LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt" + ) + Q = state_dict["weight"].data + bias = state_dict["bias"] if "bias" in state_dict.keys() else None + + else: + Q = sequential_layers[layer_name].weight.data + if weight_config_this_layer["act_order"]: + Q.copy_(Q[:, gptq_perm]) + if is_transformers_imported() and isinstance(sequential_layers[layer_name], transformers.Conv1D): + Q = Q.t_().contiguous() + from .utility import quant_weight_w_scale + + quant_weight_w_scale( + Q, + gptq_scale, + gptq_zp, + weight_config_this_layer["group_size"], + dtype=weight_config_this_layer["dtype"], + ) + if weight_config_this_layer["act_order"]: + invperm = torch.argsort(gptq_perm) + Q.copy_(Q[:, invperm]) + int_weight = Q.type(torch.int32) # copy_ is not workable for different types. 
+ # replace module + if isinstance(sequential_layers[layer_name], torch.nn.Linear): + in_features = sequential_layers[layer_name].in_features + out_features = sequential_layers[layer_name].out_features + elif is_transformers_imported() and isinstance(sequential_layers[layer_name], transformers.Conv1D): + in_features = sequential_layers[layer_name].weight.shape[0] + out_features = sequential_layers[layer_name].weight.shape[1] + int_weight = sequential_layers[layer_name].weight.t_().contiguous() + scale = scale.t_().contiguous() + zp = zp.t_().contiguous() if zp is not None else zp + + if not self.use_layer_wise: + bias = sequential_layers[layer_name].bias + + new_module = INCWeightOnlyLinear( + in_features, + out_features, + dtype=weight_config_this_layer["dtype"], + bits=weight_config_this_layer["bits"], + group_size=weight_config_this_layer["group_size"], + zp=gptq_zp is not None, + bias=bias is not None, + g_idx=gptq_perm is not None, + device=self.device, + ) + new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm) + set_module(transformer_block, layer_name, new_module) + + del gptq_for_this_block + torch.cuda.empty_cache() + # iteratively replace the input with output, thus layerwise quantization can continue. + self.update_blockwise_hidden_states(outs) + logger.info("------------------------------") + # 2.7.1 do the post transformer blocks quantization + do_post_transformer_quant = self.quant_lm_head + if do_post_transformer_quant: + logger.info("Quantizing post transformer layers") + # the input should be self.cache_key_arguments and self.cache_positional_arguments + sub_layers = find_layers(self.gptq_related_blocks["transformers_post"]["layer"]) + sub_layers_to_quant = {} for layer_name, layer_obj in sub_layers.items(): # filter sub_layers with included layer_names in self.weight_config - full_layer_name = self.get_full_layer_name(layer_name, block_idx) + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] # if self.weight_config.get(full_layer_name, None) == None: if self.get_layer_config(full_layer_name) is None: logger.warning(f"{full_layer_name} can be quantized " + "but excluded from quantization configs.") else: - sub_layers_to_quant[layer_name] = layer_obj + sub_layers_to_quant[full_layer_name] = layer_obj del sub_layers sub_layers = sub_layers_to_quant - # Step 2.2: Initialize GPTQ quantizers for collected layers. 
- gptq_for_this_block = {} - # initialize gptq quantizer for every layer in a transformer block + gptq_post_block = {} + + def add_batch_post(_name): + def tmp(_, inp, out): + gptq_post_block[_name].add_batch(inp[0].data, out.data) + + return tmp + for layer_name in sub_layers: - # weight_config_this_layer = self.weight_config.get( - # self.get_full_layer_name(layer_name, block_idx), None - # ) - full_layer_name = self.get_full_layer_name(layer_name, block_idx) + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] weight_config_this_layer = self.get_layer_config(full_layer_name) if self.use_layer_wise: # pragma: no cover from neural_compressor.torch.algorithms.layer_wise import load_value + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] W = load_value(self.model, full_layer_name + ".weight", self.model_path) else: W = sub_layers[layer_name].weight.data.clone() - gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) + gptq_post_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() - gptq_for_this_block[layer_name].quantizer.configure(weight_config_this_layer) - - # Step 2.3: modify forward functions to hook inputs data (used in gptq execution) - def add_batch(_name): - def tmp(_, inp, out): - gptq_for_this_block[_name].add_batch(inp[0].data, out.data) # noqa: F821 - - return tmp - + gptq_post_block[layer_name].quantizer.configure(weight_config_this_layer) + # generate the gptq quantizer handles = [] # register handles which add inputs and outputs to gptq object for layer_name in sub_layers: - handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) - batch_num = self.cache_key_arguments.pop("batch_num") - for j in range(batch_num): - cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) - cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) - accelerator.mark_step() - out = transformer_block(*cache_positional_batch, **cache_keyword_batch) - out = self.track_hidden_states(out) - self.cache_key_arguments["batch_num"] = batch_num + handles.append(sub_layers[layer_name].register_forward_hook(add_batch_post(layer_name))) + for j in range(len(self.dataloader)): + if "hidden_states" in self.cache_key_arguments: + out = sub_layers[layer_name](self.cache_key_arguments["hidden_states"][j]) + else: + out = sub_layers[layer_name](self.cache_positional_arguments[0][j]) + + # if "hidden_states" in self.cache_key_arguments: + # self.cache_key_arguments["hidden_states"] = outs[:] + # else: + # self.cache_positional_arguments[0] = outs[:] + # perform the inference process + for h in handles: h.remove() - # Step 2.4: everything is prepared, so start quantization! 
- for layer_name in sub_layers: - # weight_config_this_layer = self.weight_config.get( - # self.get_full_layer_name(layer_name, block_idx), None - # ) - weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) - logger.info(f"Quantizing layer {layer_name}") - if self.use_layer_wise: # pragma: no cover - from neural_compressor.torch.algorithms.layer_wise import load_value - full_layer_name = self.get_full_layer_name(layer_name, block_idx) - W = load_value(self.model, full_layer_name + ".weight", self.model_path) - else: - W = sub_layers[layer_name].weight.data.clone() - accelerator.mark_step() - if "hpu" in self.device: - W = W.to("cpu") - scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + for layer_name in sub_layers: + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] + weight_config_this_layer = self.get_layer_config(full_layer_name) + scale, zp, Q = gptq_post_block[layer_name].fasterquant( W, blocksize=weight_config_this_layer["block_size"], percdamp=weight_config_this_layer["percdamp"], @@ -577,7 +836,7 @@ def tmp(_, inp, out): ) sub_layer = sub_layers[layer_name] - full_layer_name = self.get_full_layer_name(layer_name, block_idx) + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] for n, p in sub_layer.named_parameters(): param_name = full_layer_name + "." + n if n == "weight": @@ -592,51 +851,39 @@ def tmp(_, inp, out): gc.collect() else: sub_layers[layer_name].weight.data = Q - gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} + # save the quantization results + gptq_config[full_layer_name] = {"scale": scale} if not weight_config_this_layer["sym"]: - gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp - if weight_config_this_layer["act_order"]: # save perm for restoring the weights - gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] = gptq_for_this_block[ - layer_name - ].perm - gptq_for_this_block[layer_name].free() - - # Step 2.5: replace output data with quantized weights - outs = [] - batch_num = self.cache_key_arguments.pop("batch_num") - for j in range(batch_num): - cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) - cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) - out = transformer_block(*cache_positional_batch, **cache_keyword_batch) - out = self.track_hidden_states(out) - outs.append(out) - self.cache_key_arguments["batch_num"] = batch_num - if self.use_layer_wise: # pragma: no cover - self.gptq_related_blocks["transformers"][block_idx] = transformer_block - else: - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() - # Step 2.6: export to compressed model + gptq_config[full_layer_name]["zero"] = zp + if weight_config_this_layer["act_order"] and not weight_config_this_layer["static_groups"]: + # save perm for restoring the weights, but only when static_groups is not enabled. 
+ gptq_config[full_layer_name]["perm"] = gptq_post_block[full_layer_name].perm + gptq_post_block[layer_name].free() + + # 2.7.2 lm_head: export to compressed model for layer_name in sub_layers: - weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) - gptq_scale = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["scale"] + full_layer_name = self.gptq_related_blocks["transformers_post"]["name"] + weight_config_this_layer = self.get_layer_config(full_layer_name) + gptq_scale = gptq_config[full_layer_name]["scale"] if not weight_config_this_layer["sym"]: - gptq_zp = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] + gptq_zp = gptq_config[full_layer_name]["zero"] else: gptq_zp = None if weight_config_this_layer["act_order"]: # save perm for restoring the weights - gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] + gptq_perm = gptq_config[full_layer_name]["perm"] else: gptq_perm = None - if self.use_layer_wise: - state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt") + if self.use_layer_wise: # pragma: no cover + state_dict = torch.load(LWQ_WORKSPACE + f"/{full_layer_name}.pt") Q = state_dict["weight"].data bias = state_dict["bias"] if "bias" in state_dict.keys() else None - else: Q = sub_layers[layer_name].weight.data if weight_config_this_layer["act_order"]: Q.copy_(Q[:, gptq_perm]) - if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D): + if is_transformers_imported() and isinstance( + sub_layers[layer_name], transformers.Conv1D + ): # pragma: no cover Q = Q.t_().contiguous() from .utility import quant_weight_w_scale @@ -655,14 +902,16 @@ def tmp(_, inp, out): if isinstance(sub_layers[layer_name], torch.nn.Linear): in_features = sub_layers[layer_name].in_features out_features = sub_layers[layer_name].out_features - elif is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D): + elif is_transformers_imported() and isinstance( + sub_layers[layer_name], transformers.Conv1D + ): # pragma: no cover in_features = sub_layers[layer_name].weight.shape[0] out_features = sub_layers[layer_name].weight.shape[1] int_weight = sub_layers[layer_name].weight.t_().contiguous() scale = scale.t_().contiguous() zp = zp.t_().contiguous() if zp is not None else zp - if not self.use_layer_wise: + if not self.use_layer_wise: # pragma: no cover bias = sub_layers[layer_name].bias new_module = INCWeightOnlyLinear( @@ -677,12 +926,7 @@ def tmp(_, inp, out): device=self.device, ) new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm) - set_module(transformer_block, layer_name, new_module) - del gptq_for_this_block - torch.cuda.empty_cache() - # iteratively replace the input with output, thus layerwise quantization can continue. 
- self.update_blockwise_hidden_states(outs) - logger.info("------------------------------") + set_module(self.model, layer_name, new_module) logger.info("Quantization done") # self.model.config.use_cache = self.use_cache @@ -1097,6 +1341,7 @@ def prepare( device=None, use_layer_wise=False, model_path=None, + quant_lm_head=False, *args, **kwargs, ): @@ -1116,6 +1361,7 @@ def prepare( device=device, use_layer_wise=use_layer_wise, model_path=model_path, + quant_lm_head=quant_lm_head, ) self.gptq_quantizer.prepare_for_calibration() return self.gptq_quantizer.model diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index fe182ddadd9..6ce9b49fac8 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -177,6 +177,8 @@ def convert( if dtype != "int" and "int" in dtype: bits = int(dtype.lstrip("int")) dtype = "int" + else: + continue log_msg = ( f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}" diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index fd5ce80455b..3a009d1aa65 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -159,11 +159,13 @@ def gptq_entry( "percdamp": quant_config.percdamp, "block_size": quant_config.block_size, "static_groups": quant_config.static_groups, + "true_sequential": quant_config.true_sequential, } kwargs.update( { "use_layer_wise": quant_config.use_layer_wise, "model_path": quant_config.model_path, + "quant_lm_head": quant_config.quant_lm_head, } ) kwargs.pop("example_inputs") diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 2973dc48a51..c7b19683882 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -351,6 +351,7 @@ class GPTQConfig(TorchBaseConfig): "percdamp", "block_size", "static_groups", + "true_sequential", ] def __init__( @@ -376,6 +377,7 @@ def __init__( percdamp: float = 0.01, block_size: int = 2048, static_groups: bool = False, + true_sequential: bool = False, # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): @@ -404,10 +406,12 @@ def __init__( static_groups (bool): Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements. Default is False. + true_sequential (bool): Whether to quantize layers within a transformer block in their original order. + This can lead to higher accuracy but slower overall quantization process. + Default is False. white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types. Default is DEFAULT_WHITE_LIST. """ - assert not quant_lm_head, "GPTQ doesn't support lm_head quantization currently, it's coming soon!" super().__init__(white_list=white_list) self.dtype = dtype self.bits = bits @@ -428,6 +432,7 @@ def __init__( self.percdamp = percdamp self.block_size = block_size self.static_groups = static_groups + self.true_sequential = true_sequential self.quant_lm_head = quant_lm_head self._post_init() # initialize global & local configuration @@ -599,7 +604,7 @@ def __init__( double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4. 
double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True. double_quant_group_size (int): Size of double_quant groups, default is 32. - quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformer, default is False. use_auto_scale (bool): Enables best scales search based on activation distribution, default is True. use_auto_clip (bool): Enables clip range search. Defaults to True. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index 92d7d0b790e..4d8134b07d1 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -182,9 +182,10 @@ def test_act_order(self): # compare atol, this case is an ideal case. assert atol_false > atol_true, "act_order=True doesn't help accuracy, maybe is reasonable, please double check." - def test_layer_wise(self): + @pytest.mark.parametrize("quant_lm_head", [False, True]) + def test_layer_wise(self, quant_lm_head): model = copy.deepcopy(self.tiny_gptj) - quant_config = GPTQConfig() + quant_config = GPTQConfig(quant_lm_head=quant_lm_head) model = prepare(model, quant_config) run_fn(model) model = convert(model) @@ -194,12 +195,76 @@ def test_layer_wise(self): model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM") - quant_config = GPTQConfig(use_layer_wise=True, model_path="hf-internal-testing/tiny-random-GPTJForCausalLM") + quant_config = GPTQConfig( + use_layer_wise=True, + quant_lm_head=quant_lm_head, + model_path="hf-internal-testing/tiny-random-GPTJForCausalLM", + ) + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + out = model(self.example_inputs)[0] + + # remove lwq tmp directory + from neural_compressor.torch.algorithms.layer_wise.utils import LWQ_WORKSPACE + + shutil.rmtree(LWQ_WORKSPACE, ignore_errors=True) + assert torch.equal( + out, q_label + ), f"use_layer_wise=True and quant_lm_head={quant_lm_head} output should be same. Please double check." + + def test_true_sequential(self): + # true_sequential=False + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + true_sequential=False, + ) + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + out = model(self.example_inputs)[0] + atol_false = (out - self.label).amax() + # true_sequential=True + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + true_sequential=True, + ) + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + out = model(self.example_inputs)[0] + atol_true = (out - self.label).amax() + # compare atol, this case is an ideal case. + assert ( + atol_false < atol_true + ), "true_sequential=True doesn't help accuracy, maybe is reasonable, please double check." + + def test_quant_lm_head(self): + # quant_lm_head=False + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + quant_lm_head=False, + ) model = prepare(model, quant_config) run_fn(model) model = convert(model) out = model(self.example_inputs)[0] - assert torch.equal(out, q_label), "use_layer_wise=True output should be same. Please double check." 
+ atol_false = (out - self.label).amax() + # quant_lm_head=True + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + quant_lm_head=True, + ) + model = prepare(model, quant_config) + run_fn(model) + model = convert(model) + out = model(self.example_inputs)[0] + atol_true = (out - self.label).amax() + # compare atol, this case is an ideal case. + assert ( + atol_false < atol_true + ), "quant_lm_head=True doesn't help accuracy, maybe is reasonable, please double check." + assert get_woq_linear_num(model, "INCWeightOnlyLinear") == 31, "Incorrect number of INCWeightOnlyLinear modules" @pytest.mark.parametrize("dtype", ["nf4", "int4"]) @pytest.mark.parametrize("double_quant_bits", [6]) diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index d63672539ae..0dd1d7effd0 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -174,6 +174,15 @@ def test_quant_lm_head(self): ), "The tied lm_head weight is not deep copied, please check!" def test_layer_wise(self): + # use_layer_wise=False + model = copy.deepcopy(self.tiny_gptj) + quant_config = RTNConfig( + use_layer_wise=False, + ) + model = prepare(model, quant_config) + model = convert(model) + out0 = model(self.example_inputs)[0] + from neural_compressor.torch import load_empty_model model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM") @@ -182,8 +191,8 @@ def test_layer_wise(self): ) model = prepare(model, quant_config) model = convert(model) - out = model(self.example_inputs)[0] - assert torch.equal(out, self.q_label), "use_layer_wise=True output should be same. Please double check." + out1 = model(self.example_inputs)[0] + assert torch.equal(out1, out0), "use_layer_wise=True output should be same. Please double check." @pytest.mark.parametrize( "dtype",
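As a closing note on the `true_sequential` changes in `gptq.py` above: the ordering logic boils down to the grouping heuristic in `analyze_true_sequential`. A fused qkv projection (e.g. bloom's single query_key_value matrix) forms its own group; otherwise the first three projections (q/k/v) are grouped together and every remaining linear layer becomes its own group. With `true_sequential=False` the whole block stays in a single group (`sequentials = [list(sub_layers.keys())]`). The snippet below is a standalone illustration of that heuristic, not part of the patch, and the layer names are made-up GPT-J-style names.

```python
# Standalone sketch of the grouping heuristic used by analyze_true_sequential() above.
# The layer names are illustrative; real names come from find_layers() on a transformer block.
layers = ["attn.q_proj", "attn.k_proj", "attn.v_proj", "attn.out_proj", "mlp.fc_in", "mlp.fc_out"]

if "q" in layers[0].lower() and "k" in layers[0].lower():
    # case 1: query, key and value come from one fused matrix (e.g. bloom's query_key_value)
    groups = [[layers[0]]] + [[name] for name in layers[1:]]
else:
    # case 2: separate q/k/v projections are grouped together; the rest are quantized one by one
    groups = [layers[0:3]] + [[name] for name in layers[3:]]

print(groups)
# [['attn.q_proj', 'attn.k_proj', 'attn.v_proj'], ['attn.out_proj'], ['mlp.fc_in'], ['mlp.fc_out']]
```

Each group is calibrated and quantized before the next group runs its forward pass, which is what makes the sequential variant slower but closer to the activations the quantized block will actually see.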