From b886701f46c1b2eb40e89f39958d30014a3c22f0 Mon Sep 17 00:00:00 2001
From: Yiyang Cai <49231152+YIYANGCAI@users.noreply.github.com>
Date: Thu, 21 Sep 2023 09:09:31 +0800
Subject: [PATCH] Support GPTQ on ChatGLM2-6B (#1269)

Signed-off-by: YIYANGCAI
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../ptq_weight_only/run-gptq-llm.py           | 15 ++++++-----
 neural_compressor/adaptor/torch_utils/gptq.py | 27 ++++++++++++++++---
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
index f40748c99ba..149a030af09 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
@@ -2,6 +2,7 @@
 sys.path.append("./")
 import math
 import time
+import re
 
 import torch
 import torch.nn as nn
@@ -9,7 +10,7 @@
 from torch.nn.functional import pad
 
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
 
 import datasets
 from datasets import load_dataset
@@ -214,12 +215,14 @@ def skip(*args, **kwargs):
     torch.nn.init.normal_ = skip
 
     # model
-    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
-    model.seqlen = args.pad_max_length
-    model.eval()
+    if re.search("chatglm", args.model_name_or_path.lower()):  # chatglm requires a different way to be loaded
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+        model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
+        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
+    model = model.eval()
 
-    # dataset
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
     calib_dataset = load_dataset(args.dataset, split="train")  # default
     # calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF
     calib_dataset = calib_dataset.shuffle(seed=args.seed)
diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py
index 68f61ea71f1..28fa79766a9 100644
--- a/neural_compressor/adaptor/torch_utils/gptq.py
+++ b/neural_compressor/adaptor/torch_utils/gptq.py
@@ -237,7 +237,8 @@ def prepare_dataloader(self):
             self.obtain_first_n_samples()
         try:
             self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
-            self.cache = {"i": 0}
+            self.cache = {"i": 0}  # a dict of lists: keyword arguments ("attention_masks", "position_ids", etc.)
+            self.cache_positional_arguments = []  # a list of lists: positional arguments ("rotary_pos_emb" in chatglm)
             self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
             self.is_ready = True
         except:
@@ -405,7 +406,7 @@ def pre_quantization(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
 
         # critical: hooker function which collects inputs
-        def forward(layer, hidden_states, **kwargs):
+        def forward(layer, hidden_states, *args, **kwargs):
             # inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1
             self.inp[self.cache["i"]] = hidden_states
             self.cache["i"] += 1
@@ -417,6 +418,14 @@ def forward(layer, hidden_states, **kwargs):
                         self.cache[arg] = []
                     self.cache[arg].append(kwargs[arg])
                 continue
+            # copy positional arguments; their order matters, so be cautious!
+            # Most HF models avoid extra positional arguments, but some models still pass
+            # positional arguments other than hidden_states, e.g. chatglm2-6b.
+            for idx, item in enumerate(args):
+                if (idx + 1) > len(self.cache_positional_arguments):
+                    # initialize
+                    self.cache_positional_arguments.append([])
+                self.cache_positional_arguments[idx].append(item)
             raise ValueError
 
         # Step1: fetch the embeddings and other layers before the transformer stack.
@@ -459,11 +468,19 @@ def forward(layer, hidden_states, **kwargs):
         logger.info("GPTQ quantization prepared.")
 
     def gather_single_batch_from_dict(self, data_dict, idx):
+        # obtain a single batch of keyword inputs from the cache
         single_batch = {}
         for k, v in data_dict.items():
             single_batch[k] = data_dict[k][idx]
         return single_batch
 
+    def gather_single_batch_from_list(self, data_list, idx):
+        # obtain a single batch of positional inputs from the cache
+        single_batch = []
+        for data_item in data_list:
+            single_batch.append(data_item[idx])
+        return single_batch
+
     @torch.no_grad()
     def execute_quantization(self, means=None, stds=None):
         """Run quantization."""
@@ -520,7 +537,8 @@ def tmp(_, inp, out):
             for j in range(len(self.dataloader)):
                 # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
                 cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+                self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
             self.cache["i"] = idx
             for h in handles:
                 h.remove()
@@ -551,7 +569,8 @@ def tmp(_, inp, out):
             for j in range(len(self.dataloader)):
                 # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
                 cache_batch = self.gather_single_batch_from_dict(self.cache, j)
-                self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
+                cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+                self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
             self.cache["i"] = idx
             self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
             del gptq_for_this_block
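
The gptq.py changes above boil down to a capture-and-replay pattern: the temporary forward() hook records, per calibration sample, the hidden states plus any positional and keyword arguments the first transformer block receives, and the replay loops later reconstruct each sample's call as block(inp, *positional, **keyword). The snippet below is a minimal, self-contained sketch of that idea, for illustration only; ToyBlock, capture_forward and the cache names are invented for this example and are not part of neural_compressor or transformers.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a transformer block that takes an extra positional argument
    (like chatglm2's rotary_pos_emb) in addition to keyword arguments."""

    def __init__(self, hidden_size=8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, rotary_pos_emb, attention_mask=None):
        out = self.linear(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask
        return (out + rotary_pos_emb,)


block = ToyBlock()

# Per-sample caches, mirroring self.inp / self.cache / self.cache_positional_arguments.
inputs, kw_cache, pos_cache = [], {}, []


def capture_forward(layer, hidden_states, *args, **kwargs):
    # Record the inputs of one calibration sample, then abort the forward pass.
    inputs.append(hidden_states)
    for name, value in kwargs.items():  # keyword arguments -> dict of lists
        if isinstance(value, torch.Tensor):
            kw_cache.setdefault(name, []).append(value)
    for idx, value in enumerate(args):  # positional arguments -> list of lists, order preserved
        if idx + 1 > len(pos_cache):
            pos_cache.append([])
        pos_cache[idx].append(value)
    raise ValueError  # stop after the first block, as the hook in gptq.py does


# Calibration phase: temporarily swap in the capturing forward and feed a few samples.
original_forward = block.forward
block.forward = lambda *a, **kw: capture_forward(block, *a, **kw)
for _ in range(3):
    try:
        block(torch.randn(1, 4, 8), torch.randn(1, 4, 8), attention_mask=torch.ones(1, 4, 1))
    except ValueError:
        pass
block.forward = original_forward

# Replay phase: rebuild each sample's arguments and run the real block.
for j in range(len(inputs)):
    kw_batch = {k: v[j] for k, v in kw_cache.items()}   # cf. gather_single_batch_from_dict
    pos_batch = [column[j] for column in pos_cache]     # cf. gather_single_batch_from_list
    out = block(inputs[j], *pos_batch, **kw_batch)[0]
    print(out.shape)  # torch.Size([1, 4, 8])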