diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md index b85b5c57153..9fae63388c0 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md @@ -10,7 +10,21 @@ The retraining free pruning feature is still in development, please stay tuned. ## 1. Environment PyTorch 1.8 or higher version is needed with pytorch_fx backend -The loading of llama models requires transformers version 4.28.0 or higher. +The required transformers version varies across model types. The transformers versions used to run each model in our experiments are listed below for reference.
+| Model | Transformers version |
+| :----: | :----: |
+| EleutherAI/gpt-j-6b | 4.28/4.30/4.34/4.36 |
+| huggyllama/llama-7b | 4.28/4.30/4.34/4.36 |
+| meta-llama/Llama-2-7b-hf | 4.30/4.34/4.36 |
+| facebook/opt-6.7b | 4.28/4.30/4.34/4.36 |
+| databricks/dolly-v2-3b | 4.28/4.30/4.34/4.36 |
+| tiiuae/falcon-7b | 4.28/4.30/4.34/4.36 |
+| mosaicml/mpt-7b | 4.28/4.30/4.34/4.36 |
+| bigscience/bloom-7b1 | 4.28/4.30/4.34/4.36 |
+| baichuan-inc/Baichuan-7B | 4.28/4.30 |
+| Qwen/Qwen-7B | 4.28/4.30/4.34/4.36 |
+| THUDM/chatglm3-6b | 4.34/4.36 |
+| mistralai/Mistral-7B-v0.1 | 4.34/4.36 |

```shell

@@ -39,18 +53,6 @@ Pruning scripts are available for LLM sparse models such as GPT-j, BLOOM, OPT, L ## Retrain-free Results The last token accuracy for channel pruning using [the retrain-free scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/scripts/run_gptj_pruning.sh) is presented in the following table.
-| Model | Calibration dataset | Evaluation dataset | Sparsity pattern | Over MLP block sparsity |Element-wise/matmul, Gemm, conv ratio | Dense last token accuracy | Sparse last token accuracy | Relative drop |
-| :----: | :----: | :----: | :----: | :----: | :----: |:----: |:----:| :----: |
-| EleutherAI/gpt-j-6b | lambada | lambada | channelx1 | 0.1999 | 0.1242 | 0.7917 | 0.8038 | +1.50% |
-| EleutherAI/gpt-j-6b | the_pile | lambada | channelx1 | 0.0999 | 0.0643 | 0.7917 | 0.7931 | +0.17% |
-| EleutherAI/gpt-j-6b | pile_10k | lambada | channelx1 | 0.0999 | 0.0643 | 0.7917 | 0.7901 | -0.20% |
-| facebook/opt-1.3b | pile_10k | lambada | channelx1 | 0.0999 | 0.0614 | 0.7541 | 0.7498 | -0.57% |
-| facebook/opt-2.7b | pile_10k | lambada | channelx1 | 0.0999 | 0.0634 | 0.7779 | 0.7778 | -0.01% |
-| decapoda-research/llama-7b-hf | pile_10k | lambada | channelx1 | 0.0999 | 0.0654 | 0.8856 | 0.8815 | -0.46% |
-| bigscience/bloom-1b7 | pile_10k | lambada | channelx1 | 0.0999 | 0.0466 | 0.7143 | 0.7141 | -0.03% |
-| bigscience/bloom-7b1 | pile_10k | lambada | channelx1 | 0.0999 | 0.0568 | 0.7745 | 0.7742 | -0.04% |
-
-
The last word acc of the channel-wise sparse model is shown in the following table. All the sparsity is 10% over MLP block. | Model | Task | Calibration dataset | Evaluation dataset | Precision | Dense last word accuracy | Sparse last word accuracy | Relative drop | @@ -68,29 +70,39 @@ The last word acc of the channel-wise sparse model is shown in the following tab | bigscience/bloom-7b1 | CLM | pile_10k | lambada_openai | FP32 | 0.5764 | 0.5791 | 0.47% | | bigscience/bloom-7b1 | CLM | pile_10k | lambada_openai | BF16 | 0.5723 | 0.5756 | 0.58% | -
+ ## SparseGPT Results The last word acc of the 1x1 pattern sparse model using [the sparseGPT script](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/scripts/run_llm_sparsegpt.sh) is shown in the following table. + | Model | Task | Calibration dataset | Evaluation dataset | Sparsity | Precision | Dense last word accuracy | Sparse last word accuracy | Relative drop | | :----: | :----: | :----: | :----: | :----: | :----: | :----: |:----: |:----:| +| meta-llama/Llama-2-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 30% | FP32 | 0.7392 | 0.7320 | -0.97% | +| meta-llama/Llama-2-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 30% | BF16 | 0.7365 | 0.7304 | -1.19% | | EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6831 | 0.6922 | +1.33% | -| EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6771 | 0.6874 | +1.52% | +| EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6771 | 0.6874 | +0.63% | | decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7361 | 0.7332 | -0.39% | -| decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7326 | 0.7297 | -0.23% | +| decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7326 | 0.7297 | -0.87% | | facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6769 | 0.6616 | -2.26% | -| facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6730 | 0.6577 | -2.27% | -| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7467 | 0.7528 | -0.82% | -| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7464 | 0.7502 | -0.51% | +| facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6730 | 0.6577 | -2.84% | +| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7467 | 0.7528 | +0.82% | +| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7464 | 0.7502 | +0.47% | | bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.5764 | 0.5606 | -2.74% | -| bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.5725 | 0.5587 | -2.41% | +| bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.5725 | 0.5587 | -3.07% | | mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7056 | 0.7035 | -0.30% | -| mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6831 | 0.6856 | +0.37% | +| mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6831 | 0.6856 | -2.83% | | mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6550 | 0.6561 | +0.17% | -| mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6456 | 0.6451 | -0.08% | +| mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6456 | 0.6451 | -1.51% | +| meta-llama/Llama-2-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7679 | 0.7629 | -0.65% | +| meta-llama/Llama-2-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7667 | 0.7601 | -1.02% | | decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 50% | FP32 | 0.7627 | 0.7559 | -0.89% | -| decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 
50% | BF16 | 0.7599 | 0.7559 | -0.53% | +| decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 50% | BF16 | 0.7599 | 0.7559 | -0.89% | +| meta-llama/Llama-2-70b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | FP32 | 0.7964 | 0.7951 | -0.16% | +| meta-llama/Llama-2-70b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | BF16 | 0.7937 | 0.7943 | -0.26% | +| Qwen/Qwen-72B | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | FP32 | - | - | - | +| Qwen/Qwen-72B | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | BF16 | 0.7673 | 0.7813 | - | + ## References @@ -102,4 +114,3 @@ The last word acc of the 1x1 pattern sparse model using [the sparseGPT script](h - diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt index b58987bddea..5b0d8f0bd00 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt @@ -1,9 +1,13 @@ accelerate datasets +einops +intel_extension_for_transformers +optimum +peft sentencepiece -transformers +transformers==4.36.0 torch tqdm -optimum -einops - +tiktoken +transformers_stream_generator +git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2 diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py index 056496d968a..970580dc63b 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py @@ -5,10 +5,10 @@ import math import os import sys -sys.path.insert(0, './neural-compressor') sys.path.insert(0, './') - +sys.path.insert(0, './neural-compressor') import random +import re import numpy as np from itertools import chain from pathlib import Path @@ -36,23 +36,24 @@ def skip(*args, **kwargs): import transformers from transformers import ( + AutoModelForCausalLM, + AutoModel, CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, - AutoModelForCausalLM, - OPTForCausalLM, AutoTokenizer, SchedulerType, default_data_collator, - get_scheduler, ) -from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry + +from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -from neural_compressor.training import prepare_compression -from neural_compressor.training import WeightPruningConfig from timers import CPUTimer, GPUTimer -from neural_compressor.compression.pruner import model_slim -from neural_compressor.compression.pruner import parse_auto_slim_config +from neural_compressor.training import WeightPruningConfig +from neural_compressor.compression.pruner import (prepare_pruning, + parse_auto_slim_config) +from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate + check_min_version("4.27.0.dev0") logger = logging.getLogger(__name__) @@ -64,99 +65,6 @@ def skip(*args, **kwargs): RANK = int(os.getenv('RANK', -1)) WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) -class Evaluator: - def __init__(self, dataset, tokenizer, device, batch_size=16): - self.dataset = dataset - self.tokenizer = tokenizer - self.device = device - self.dataloader = 
INCDataloader(dataset, tokenizer, self.device, batch_size) - - @torch.no_grad() - def evaluate(self, model): - model.eval() - # The task is to predict the last word of the input. - total, hit = 0, 0 - if torch.cuda.is_available(): - my_timer = GPUTimer(timelogs = []) - else: - my_timer = CPUTimer(timelogs = []) - warmup_steps = 10 - step = 0 - for input_ids, label, label_indices in tqdm(self.dataloader): - with torch.no_grad(): - # if step == 0: - # model = torch.jit.trace(model, input_ids) - step += 1 - # timing - if step > warmup_steps: my_timer.__enter__() - outputs = model(input_ids) - if step > warmup_steps: my_timer.__exit__() - last_token_logits = outputs[0][torch.arange(len(label_indices)), label_indices, :] - pred = last_token_logits.argmax(dim=-1) - total += label.size(0) - hit += (pred == label).sum().item() - if step % 100 == 0: - logger.info(f"eval step:{step} accuracy:{float(hit/total)}") - avg_latency = my_timer.get_avg_time() - del my_timer - accuracy = hit / total - return accuracy, avg_latency - - -class INCDataloader(): - def __init__(self, dataset, tokenizer, device, batch_size=1): - self.dataset = dataset - self.tokenizer = tokenizer - self.device = device - self.batch_size = batch_size - import math - self.length = math.ceil(len(dataset) / self.batch_size) - self.pad_len = 196 - - # tokenize the dataset - def tokenize_function(examples): - example = self.tokenizer(examples['text']) - return example - - self.dataset = self.dataset.map(tokenize_function, batched=True) - self.dataset.set_format(type='torch', columns=['input_ids']) - - def pad_input(self, input): - input_id = input['input_ids'].unsqueeze(0).to(self.device) - label = input_id[:, -1].to(self.device) - pad_len = self.pad_len - input_id.shape[1] - label_index = -2 - pad_len - input_id = pad(input_id, (0, pad_len), value=1) - - return (input_id, label, label_index) - - def __iter__(self): - input_ids = None - labels = None - label_indices = None - for idx, batch in enumerate(self.dataset): - input_id, label, label_index = self.pad_input(batch) - - if input_ids is None: - input_ids = input_id - labels = label - label_indices = [label_index] - else: - input_ids = torch.cat((input_ids, input_id), 0).to(self.device) - labels = torch.cat((labels, label), 0).to(self.device) - label_indices.append(label_index) - - if (idx + 1) % self.batch_size == 0: - yield (input_ids, labels, label_indices) - input_ids = None - labels = None - label_indices = None - if (idx + 1) % self.batch_size != 0: - yield (input_ids, labels, label_indices) - - def __len__(self): - return self.length - def parse_args(): parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") parser.add_argument( @@ -165,12 +73,6 @@ def parse_args(): default="wikitext-2-raw-v1", help="The name of the pruning dataset to use (via the datasets library).", ) - parser.add_argument( - "--evaluation_dataset_name", - type=str, - default=None, - help="The name of the evaluation dataset to use (via the datasets library).", - ) parser.add_argument( "--dataset_config_name", type=str, @@ -328,17 +230,18 @@ def parse_args(): "If passed, LLM loading time and RAM consumption will be benefited." 
), ) + + ### DDP mode config + parser.add_argument( + "--local_rank", + type=int, default=-1, + help="Automatic DDP Multi-GPU argument, do not modify") + # pruning config parser.add_argument( "--do_prune", action="store_true", help="Whether or not to prune the model" ) - parser.add_argument( - "--max_pruning_steps", - type=int, - default=None, - help="Total number of pruning steps to perform. If provided", - ) parser.add_argument( "--pruning_pattern", type=str, default="1x1", @@ -349,15 +252,6 @@ def parse_args(): type=float, default=0.5, help="Target sparsity of the model." ) - parser.add_argument( - "--pruning_frequency", - type=int, default=-1, - help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." - ) - parser.add_argument( - "--auto_slim", action="store_true", - help="Whether or not to auto slim the model after pruning." - ) parser.add_argument( "--auto_config", action="store_true", help="Whether to automatically generate pruning configs." @@ -371,15 +265,16 @@ def parse_args(): "--trust_remote_code", default=True, help="Transformers parameter: use the external repo") - ### DDP mode config - parser.add_argument( - "--local_rank", - type=int, default=-1, - help="Automatic DDP Multi-GPU argument, do not modify") - - parser.add_argument("--eval_fp16", action='store_true', - help=" fp16") - + # Evaluation config + parser.add_argument("--tasks", default=["lambada_openai"], + help="Usually chosen with ['lambada_openai','hellaswag','winogrande','piqa']", + ) + parser.add_argument("--use_accelerate", action='store_true', + help="Usually use to accelerate evaluation for large models" + ) + parser.add_argument("--eval_dtype", default='fp32', + help="choose in bf16, fp16 and fp32" + ) args = parser.parse_args() @@ -484,14 +379,14 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") is_llama = bool("llama" in args.model_name_or_path) - is_t5 = bool("t5" in args.model_name_or_path) if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) elif args.model_name_or_path: if is_llama: tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path) else : - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, + use_fast=not args.use_slow_tokenizer, trust_remote_code=True) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." 
@@ -499,11 +394,9 @@ def main(): ) if args.model_name_or_path: - if is_t5: - model = T5ForConditionalGeneration.from_pretrained( - args.model_name_or_path, - config=config, - ) + if re.search("chatglm", args.model_name_or_path.lower()): + model = AutoModel.from_pretrained(args.model_name_or_path, + trust_remote_code=args.trust_remote_code) # .half() else: model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, @@ -621,7 +514,8 @@ def group_texts(examples): pruning_configs=[ { "pruning_type": "sparse_gpt", - "op_names": [".attn", "_proj", ".fc", "key", "dense", "_h"], + "op_names": [".*"], + "excluded_op_names": ["lm_head", "embed_out"], } ] else: @@ -640,73 +534,60 @@ def group_texts(examples): target_sparsity=args.target_sparsity, pattern=args.pruning_pattern, ) - + + device = args.device + if device != 'cpu': + device = "cuda:"+str(device) + if args.do_prune: torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False use_cache = model.config.use_cache model.config.use_cache = False - # if torch.cuda.is_available(): # Larger models(e.g. 80G+) may not load into the video card memory. - # model = model.cuda() - device = args.device - if device != 'cpu': - device = "cuda:"+str(device) - from neural_compressor.training import prepare_pruning - pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device) + + pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device) model.config.use_cache = use_cache if args.output_dir is not None: ###TODO set ddp save method output_dir = args.output_dir - if args.auto_slim: - output_dir += "/before_slim" model.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) logger.info(f"The model has been exported to {output_dir}") if device != 'cpu': - model = model.to(device) + if not args.use_accelerate: + model = model.to(device) + else: + model = model.cpu() logger.info(f"***** Evaluation in GPU mode. *****") else: logger.info(f"***** Evaluation in CPU mode. 
*****") model.eval() - if args.evaluation_dataset_name != None: - dataset_eval = load_dataset( - # for example:use the_pile's validation set for pruning, and lambada dataset for eval - args.evaluation_dataset_name, - args.dataset_config_name, - split=f"validation", - ) - else: - dataset_eval = raw_datasets["validation"] - dataset_eval = dataset_eval.shuffle(seed=42) - evaluator = Evaluator(dataset_eval, tokenizer, model.device, batch_size=args.per_device_eval_batch_size) - def eval_func(model): - acc, avg_latency = evaluator.evaluate(model) - return acc, avg_latency - if not args.auto_slim: - # only eval - logger.info(f"***** Running Evaluation *****") - acc, _ = eval_func(model) - logger.info(f"pruned model accuracy:{acc}") + model_name = args.model_name_or_path + dtype = None + if args.eval_dtype == 'bf16': + model = model.to(dtype=torch.bfloat16) + dtype = 'bfloat16' + elif args.eval_dtype == 'fp16': + dtype = 'float16' + model = model.to(dtype=torch.float16) else: - logger.info(f"***** Running Evaluation before ffn auto slim*****") - accuracy, avg_latency = eval_func(model) - logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}") - model = model_slim(model, round_multiplier=32) - - logger.info(f"***** Running Evaluation after ffn auto_slim*****") - accuracy, avg_latency = eval_func(model) - logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}") - - if RANK in {-1, 0}: - if args.output_dir is not None and args.auto_slim: - model.to('cpu') - torch.save(model, args.output_dir+"/slimed_model.pt") - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of auto slim", auto_lfs_prune=True) + dtype = 'float32' + model = model.to(dtype=torch.float32) + + model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate},trust_remote_code={args.trust_remote_code}' + eval_batch = args.per_device_eval_batch_size + user_model = None if args.use_accelerate else model + results = evaluate( + model="hf-causal", + model_args=model_args, + user_model=user_model, + batch_size=eval_batch, + tasks=args.tasks, + device=device, + ) if __name__ == "__main__": main() diff --git a/neural_compressor/compression/pruner/pruning.py b/neural_compressor/compression/pruner/pruning.py index e31847eec4d..1ded7e5f0c9 100644 --- a/neural_compressor/compression/pruner/pruning.py +++ b/neural_compressor/compression/pruner/pruning.py @@ -205,37 +205,41 @@ def _do_pruning(self): layers = self._layers self._model = self._model.cpu() - inputs, inp_dict = collect_layer_inputs( + inputs, positional_inputs, other_input_infos = collect_layer_inputs( model=self._model, layers=layers, layer_idx=0, layer_inputs=self._dataloader, device=self.dev ) - - with torch.no_grad(): - for i in tqdm(range(len(layers))): - layer = layers[i].to(self.dev) - layer_index_str = "." + str(i) + "." - handles_list = [] - for pruner in self.pruners: - layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key] + for i in tqdm(range(len(layers))): + layer = layers[i].to(self.dev) + layer_index_str = "." + str(i) + "." 
+ handles_list = [] + for pruner in self.pruners: + layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key] + if bool(layer_op_names): handles_list.append(pruner.register_gpt_hook(layer_op_names)) + prune_flag = bool(handles_list) + if prune_flag: for j in range(len(inputs)): - input_infos = self.gather_single_batch_from_dict(inp_dict, j) - layer(inputs[j], **input_infos)[0] + other_infos = self.gather_single_batch_from_dict(other_input_infos, j) + with torch.no_grad(): + layer(inputs[j], *positional_inputs, **other_infos)[0] for handles in handles_list: for h in handles: h.remove() for pruner in self.pruners: layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key] pruner.fasterprune(layer_op_names) - for j in range(len(inputs)): - # the weights of current layer have been pruned, get the latest outputs as the inputs for next layer - input_infos = self.gather_single_batch_from_dict(inp_dict, j) - inputs[j] = layer(inputs[j], **input_infos)[0] - layers[i] = layer.cpu() - if "cuda" in self.dev.type: - torch.cuda.empty_cache() - del inp_dict - del inputs - gc.collect() + for j in range(len(inputs)): + # the weights of current layer have been pruned, get the latest outputs as the inputs for next layer + other_infos = self.gather_single_batch_from_dict(other_input_infos, j) + with torch.no_grad(): + inputs[j] = layer(inputs[j], *positional_inputs, **other_infos)[0] + layers[i] = layer.cpu() + if "cuda" in self.dev.type: + torch.cuda.empty_cache() + del other_infos + del positional_inputs + del inputs + gc.collect() if "cuda" in self.dev.type: torch.cuda.empty_cache() diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py index bc43a0e02ab..d31bdb231b6 100644 --- a/neural_compressor/compression/pruner/utils.py +++ b/neural_compressor/compression/pruner/utils.py @@ -672,31 +672,17 @@ def unfoldLayer(module): return layers -def move_input_to_device(input, device): - if device is None: - device = torch.device("cpu") - elif isinstance(device, str): - device = torch.device(device) - +def move_input_to_device(input, device="cpu"): + if isinstance(input, torch.Tensor): + return input.to(device) if isinstance(input, dict) or isinstance(input, UserDict): for inp in input.keys(): - input[inp] = input[inp].to(device) if isinstance(input[inp], torch.Tensor) else input[inp] + input[inp] = move_input_to_device(input[inp], device) elif isinstance(input, list) or isinstance(input, tuple): - input_res, prev_size = [], None + input_res = [] for inp in input: - if prev_size: - if isinstance(inp, torch.Tensor): - if inp.size() == prev_size: - input_res.append(inp.to(device)) - else: - if torch.tensor(inp).size == prev_size: - input_res.append(inp) - else: - input_res.append(inp.to(device) if isinstance(inp, torch.Tensor) else inp) - prev_size = torch.tensor(inp).size() + input_res.append(move_input_to_device(inp, device)) input = input_res - else: - input = input.to(device) # pylint: disable=no-member return input @@ -712,25 +698,26 @@ def collect_layer_inputs(model, layers, layer_idx, layer_inputs, device="cuda:0" Returns: input list. 
""" inputs = [] + other_input_infos = {} + positional_inputs = [] model_dev = model.device - attention_mask = None - # 'alibi' is a necessary attribute for the bloom models - inputs_info = {} with torch.no_grad(): if layer_idx == 0: layer = layers[layer_idx] - def forward(self, hidden_states, **kwargs): + def forward(_, hidden_states, *positional_args, **kwargs): + nonlocal inputs + nonlocal positional_inputs + nonlocal other_input_infos # TODO solve the problem of batchsize!=1 - inputs.append(hidden_states.to(device)) + inputs.append(move_input_to_device(hidden_states, device)) + if len(positional_inputs) <= 0: + positional_inputs = move_input_to_device(positional_args, device) for key in kwargs.keys(): - if isinstance(kwargs[key], torch.Tensor) or (key == "alibi"): - if key not in inputs_info.keys(): - inputs_info[key] = [] - if isinstance(kwargs[key], torch.Tensor): - kwargs[key] = kwargs[key].to(device) - inputs_info[key].append(kwargs[key]) + if key not in other_input_infos.keys(): + other_input_infos[key] = [] + other_input_infos[key].append(move_input_to_device(kwargs[key], device)) raise ValueError forward_cache = layers[layer_idx].forward @@ -749,7 +736,6 @@ def forward(self, hidden_states, **kwargs): except ValueError: pass layer.forward = forward_cache - else: prev_layer = layers[layer_idx - 1] @@ -758,4 +744,4 @@ def forward(self, hidden_states, **kwargs): batch[0] = prev_output[0] inputs.append(batch) - return inputs, inputs_info + return inputs, positional_inputs, other_input_infos