Skip to content

Commit

Permalink
Support GPTQ on ChatGLM2-6B (#1269)
Browse files Browse the repository at this point in the history
Signed-off-by: YIYANGCAI <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
YIYANGCAI and pre-commit-ci[bot] authored Sep 21, 2023
1 parent 854452b commit b886701
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
sys.path.append("./")
import math
import time
import re

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.functional import pad

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
import datasets
from datasets import load_dataset

Expand Down Expand Up @@ -214,12 +215,14 @@ def skip(*args, **kwargs):
torch.nn.init.normal_ = skip

# model
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
model.seqlen = args.pad_max_length
model.eval()
if re.search("chatglm", args.model_name_or_path.lower()): # chatglm requires a different way to be loaded
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
model = model.eval()

# dataset
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
calib_dataset = load_dataset(args.dataset, split="train") # default
# calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF
calib_dataset = calib_dataset.shuffle(seed=args.seed)
Expand Down
27 changes: 23 additions & 4 deletions neural_compressor/adaptor/torch_utils/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def prepare_dataloader(self):
self.obtain_first_n_samples()
try:
self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.cache = {"i": 0}
self.cache = {"i": 0} # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.)
self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm)
self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
self.is_ready = True
except:
Expand Down Expand Up @@ -405,7 +406,7 @@ def pre_quantization(self):
"""Prepare input calibration data and other attributes which are critical for gptq execution."""

# critical: hooker function which collects inputs
def forward(layer, hidden_states, **kwargs):
def forward(layer, hidden_states, *args, **kwargs):
# inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1
self.inp[self.cache["i"]] = hidden_states
self.cache["i"] += 1
Expand All @@ -417,6 +418,14 @@ def forward(layer, hidden_states, **kwargs):
self.cache[arg] = []
self.cache[arg].append(kwargs[arg])
continue
# copy positional arguments, positional arguments are sensitive for their order, be cautious!
# Most models in HF has avoid this, but some models still use positional arguments other than
# hidden_states, chatglm2-6b etc.
for idx, item in enumerate(args):
if (idx + 1) > len(self.cache_positional_arguments):
# initialize
self.cache_positional_arguments.append([])
self.cache_positional_arguments[idx].append(item)
raise ValueError

# Step1: fetch the embeddings and other layers before the transformer stack.
Expand Down Expand Up @@ -459,11 +468,19 @@ def forward(layer, hidden_states, **kwargs):
logger.info("GPTQ quantization prepared.")

def gather_single_batch_from_dict(self, data_dict, idx):
# obtain a set of keyword input from cache
single_batch = {}
for k, v in data_dict.items():
single_batch[k] = data_dict[k][idx]
return single_batch

def gather_single_batch_from_list(self, data_list, idx):
# obtain a set of keyword input from cache
single_batch = []
for data_item in data_list:
single_batch.append(data_item[idx])
return single_batch

@torch.no_grad()
def execute_quantization(self, means=None, stds=None):
"""Run quantization."""
Expand Down Expand Up @@ -520,7 +537,8 @@ def tmp(_, inp, out):
for j in range(len(self.dataloader)):
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
self.cache["i"] = idx
for h in handles:
h.remove()
Expand Down Expand Up @@ -551,7 +569,8 @@ def tmp(_, inp, out):
for j in range(len(self.dataloader)):
# self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
cache_batch = self.gather_single_batch_from_dict(self.cache, j)
self.out[j] = transformer_block(self.inp[j], **cache_batch)[0]
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
self.cache["i"] = idx
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
del gptq_for_this_block
Expand Down

0 comments on commit b886701

Please sign in to comment.