diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md
index b85b5c57153..9fae63388c0 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/README.md
@@ -10,7 +10,21 @@ The retraining free pruning feature is still in development, please stay tuned.
## 1. Environment
PyTorch 1.8 or higher version is needed with pytorch_fx backend
-The loading of llama models requires transformers version 4.28.0 or higher.
+The required transformers version varies across model types. For reference, the table below lists the transformers versions that were used to run each model in our experiments.
+| Model | Transformers version |
+| :----: | :----: |
+| EleutherAI/gpt-j-6b | 4.28/4.30/4.34/4.36 |
+| huggyllama/llama-7b | 4.28/4.30/4.34/4.36 |
+| meta-llama/Llama-2-7b-hf | 4.30/4.34/4.36 |
+| facebook/opt-6.7b | 4.28/4.30/4.34/4.36 |
+| databricks/dolly-v2-3b | 4.28/4.30/4.34/4.36 |
+| tiiuae/falcon-7b | 4.28/4.30/4.34/4.36 |
+| mosaicml/mpt-7b | 4.28/4.30/4.34/4.36 |
+| bigscience/bloom-7b1 | 4.28/4.30/4.34/4.36 |
+| baichuan-inc/Baichuan-7B | 4.28/4.30 |
+| Qwen/Qwen-7B | 4.28/4.30/4.34/4.36 |
+| THUDM/chatglm3-6b | 4.34/4.36 |
+| mistralai/Mistral-7B-v0.1 | 4.34/4.36 |
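+
+As a quick sanity check before running a model, the snippet below (a minimal sketch; the `packaging` dependency and the chosen version string are illustrative assumptions) compares the installed transformers version against the table above:
+
+```python
+# Hypothetical check: warn if the installed transformers version is older than the
+# version listed in the table for the model you plan to run.
+import transformers
+from packaging import version  # assumption: the packaging package is available
+
+required = "4.34"  # e.g. THUDM/chatglm3-6b and mistralai/Mistral-7B-v0.1 were run with 4.34/4.36
+if version.parse(transformers.__version__) < version.parse(required):
+    print(f"transformers {transformers.__version__} found; {required}+ is recommended for this model")
+```
+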
```shell
@@ -39,18 +53,6 @@ Pruning scripts are available for LLM sparse models such as GPT-j, BLOOM, OPT, L
## Retrain-free Results
The last token accuracy for channel pruning using [the retrain-free scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/scripts/run_gptj_pruning.sh) is presented in the following table.
-| Model | Calibration dataset | Evaluation dataset | Sparsity pattern | Over MLP block sparsity |Element-wise/matmul, Gemm, conv ratio | Dense last token accuracy | Sparse last token accuracy | Relative drop |
-| :----: | :----: | :----: | :----: | :----: | :----: |:----: |:----:| :----: |
-| EleutherAI/gpt-j-6b | lambada | lambada | channelx1 | 0.1999 | 0.1242 | 0.7917 | 0.8038 | +1.50% |
-| EleutherAI/gpt-j-6b | the_pile | lambada | channelx1 | 0.0999 | 0.0643 | 0.7917 | 0.7931 | +0.17% |
-| EleutherAI/gpt-j-6b | pile_10k | lambada | channelx1 | 0.0999 | 0.0643 | 0.7917 | 0.7901 | -0.20% |
-| facebook/opt-1.3b | pile_10k | lambada | channelx1 | 0.0999 | 0.0614 | 0.7541 | 0.7498 | -0.57% |
-| facebook/opt-2.7b | pile_10k | lambada | channelx1 | 0.0999 | 0.0634 | 0.7779 | 0.7778 | -0.01% |
-| decapoda-research/llama-7b-hf | pile_10k | lambada | channelx1 | 0.0999 | 0.0654 | 0.8856 | 0.8815 | -0.46% |
-| bigscience/bloom-1b7 | pile_10k | lambada | channelx1 | 0.0999 | 0.0466 | 0.7143 | 0.7141 | -0.03% |
-| bigscience/bloom-7b1 | pile_10k | lambada | channelx1 | 0.0999 | 0.0568 | 0.7745 | 0.7742 | -0.04% |
-
-
The last word acc of the channel-wise sparse model is shown in the following table. All the sparsity is 10% over MLP block.
| Model | Task | Calibration dataset | Evaluation dataset | Precision | Dense last word accuracy | Sparse last word accuracy | Relative drop |
@@ -68,29 +70,39 @@ The last word acc of the channel-wise sparse model is shown in the following tab
| bigscience/bloom-7b1 | CLM | pile_10k | lambada_openai | FP32 | 0.5764 | 0.5791 | 0.47% |
| bigscience/bloom-7b1 | CLM | pile_10k | lambada_openai | BF16 | 0.5723 | 0.5756 | 0.58% |
-
+
## SparseGPT Results
The last word acc of the 1x1 pattern sparse model using [the sparseGPT script](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/scripts/run_llm_sparsegpt.sh) is shown in the following table.
+
| Model | Task | Calibration dataset | Evaluation dataset | Sparsity | Precision | Dense last word accuracy | Sparse last word accuracy | Relative drop |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |:----: |:----:|
+| meta-llama/Llama-2-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 30% | FP32 | 0.7392 | 0.7320 | -0.97% |
+| meta-llama/Llama-2-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 30% | BF16 | 0.7365 | 0.7304 | -1.19% |
| EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6831 | 0.6922 | +1.33% |
-| EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6771 | 0.6874 | +1.52% |
+| EleutherAI/gpt-j-6b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6771 | 0.6874 | +0.63% |
| decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7361 | 0.7332 | -0.39% |
-| decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7326 | 0.7297 | -0.23% |
+| decapoda-research/llama-7b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7326 | 0.7297 | -0.87% |
| facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6769 | 0.6616 | -2.26% |
-| facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6730 | 0.6577 | -2.27% |
-| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7467 | 0.7528 | -0.82% |
-| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7464 | 0.7502 | -0.51% |
+| facebook/opt-6.7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6730 | 0.6577 | -2.84% |
+| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7467 | 0.7528 | +0.82% |
+| tiiuae/falcon-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7464 | 0.7502 | +0.47% |
| bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.5764 | 0.5606 | -2.74% |
-| bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.5725 | 0.5587 | -2.41% |
+| bigscience/bloom-7b1 | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.5725 | 0.5587 | -3.07% |
| mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7056 | 0.7035 | -0.30% |
-| mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6831 | 0.6856 | +0.37% |
+| mosaicml/mpt-7b | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6831 | 0.6856 | -2.83% |
| mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.6550 | 0.6561 | +0.17% |
-| mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6456 | 0.6451 | -0.08% |
+| mosaicml/mpt-7b-chat | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.6456 | 0.6451 | -1.51% |
+| meta-llama/Llama-2-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | FP32 | 0.7679 | 0.7629 | -0.65% |
+| meta-llama/Llama-2-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 40% | BF16 | 0.7667 | 0.7601 | -1.02% |
| decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 50% | FP32 | 0.7627 | 0.7559 | -0.89% |
-| decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 50% | BF16 | 0.7599 | 0.7559 | -0.53% |
+| decapoda-research/llama-13b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 50% | BF16 | 0.7599 | 0.7559 | -0.89% |
+| meta-llama/Llama-2-70b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | FP32 | 0.7964 | 0.7951 | -0.16% |
+| meta-llama/Llama-2-70b-hf | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | BF16 | 0.7937 | 0.7943 | -0.26% |
+| Qwen/Qwen-72B | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | FP32 | - | - | - |
+| Qwen/Qwen-72B | CLM | wikitext-2-raw-v1 | lambada_openai | 60% | BF16 | 0.7673 | 0.7813 | - |
+
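+The relative drop in the SparseGPT table appears to be computed against the FP32 dense accuracy, even for the BF16 rows. A minimal sketch of that calculation (values taken from the meta-llama/Llama-2-7b-hf rows above):
+
+```python
+# Relative drop with respect to the FP32 dense baseline (assumed convention of this table).
+dense_fp32 = 0.7392   # dense last word accuracy, FP32
+sparse_bf16 = 0.7304  # 30% sparse last word accuracy, BF16
+relative_drop = (sparse_bf16 - dense_fp32) / dense_fp32
+print(f"{relative_drop:+.2%}")  # -1.19%, matching the table
+```
+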
## References
@@ -102,4 +114,3 @@ The last word acc of the 1x1 pattern sparse model using [the sparseGPT script](h
-
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
index b58987bddea..5b0d8f0bd00 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/requirements.txt
@@ -1,9 +1,13 @@
accelerate
datasets
+einops
+intel_extension_for_transformers
+optimum
+peft
sentencepiece
-transformers
+transformers==4.36.0
torch
tqdm
-optimum
-einops
-
+tiktoken
+transformers_stream_generator
+git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py
index 056496d968a..970580dc63b 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_sparsegpt.py
@@ -5,10 +5,10 @@
import math
import os
import sys
-sys.path.insert(0, './neural-compressor')
sys.path.insert(0, './')
-
+sys.path.insert(0, './neural-compressor')
import random
+import re
import numpy as np
from itertools import chain
from pathlib import Path
@@ -36,23 +36,24 @@ def skip(*args, **kwargs):
import transformers
from transformers import (
+ AutoModelForCausalLM,
+ AutoModel,
CONFIG_MAPPING,
MODEL_MAPPING,
AutoConfig,
- AutoModelForCausalLM,
- OPTForCausalLM,
AutoTokenizer,
SchedulerType,
default_data_collator,
- get_scheduler,
)
-from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
+
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
-from neural_compressor.training import prepare_compression
-from neural_compressor.training import WeightPruningConfig
from timers import CPUTimer, GPUTimer
-from neural_compressor.compression.pruner import model_slim
-from neural_compressor.compression.pruner import parse_auto_slim_config
+from neural_compressor.training import WeightPruningConfig
+from neural_compressor.compression.pruner import (prepare_pruning,
+ parse_auto_slim_config)
+from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+
check_min_version("4.27.0.dev0")
logger = logging.getLogger(__name__)
@@ -64,99 +65,6 @@ def skip(*args, **kwargs):
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
-class Evaluator:
- def __init__(self, dataset, tokenizer, device, batch_size=16):
- self.dataset = dataset
- self.tokenizer = tokenizer
- self.device = device
- self.dataloader = INCDataloader(dataset, tokenizer, self.device, batch_size)
-
- @torch.no_grad()
- def evaluate(self, model):
- model.eval()
- # The task is to predict the last word of the input.
- total, hit = 0, 0
- if torch.cuda.is_available():
- my_timer = GPUTimer(timelogs = [])
- else:
- my_timer = CPUTimer(timelogs = [])
- warmup_steps = 10
- step = 0
- for input_ids, label, label_indices in tqdm(self.dataloader):
- with torch.no_grad():
- # if step == 0:
- # model = torch.jit.trace(model, input_ids)
- step += 1
- # timing
- if step > warmup_steps: my_timer.__enter__()
- outputs = model(input_ids)
- if step > warmup_steps: my_timer.__exit__()
- last_token_logits = outputs[0][torch.arange(len(label_indices)), label_indices, :]
- pred = last_token_logits.argmax(dim=-1)
- total += label.size(0)
- hit += (pred == label).sum().item()
- if step % 100 == 0:
- logger.info(f"eval step:{step} accuracy:{float(hit/total)}")
- avg_latency = my_timer.get_avg_time()
- del my_timer
- accuracy = hit / total
- return accuracy, avg_latency
-
-
-class INCDataloader():
- def __init__(self, dataset, tokenizer, device, batch_size=1):
- self.dataset = dataset
- self.tokenizer = tokenizer
- self.device = device
- self.batch_size = batch_size
- import math
- self.length = math.ceil(len(dataset) / self.batch_size)
- self.pad_len = 196
-
- # tokenize the dataset
- def tokenize_function(examples):
- example = self.tokenizer(examples['text'])
- return example
-
- self.dataset = self.dataset.map(tokenize_function, batched=True)
- self.dataset.set_format(type='torch', columns=['input_ids'])
-
- def pad_input(self, input):
- input_id = input['input_ids'].unsqueeze(0).to(self.device)
- label = input_id[:, -1].to(self.device)
- pad_len = self.pad_len - input_id.shape[1]
- label_index = -2 - pad_len
- input_id = pad(input_id, (0, pad_len), value=1)
-
- return (input_id, label, label_index)
-
- def __iter__(self):
- input_ids = None
- labels = None
- label_indices = None
- for idx, batch in enumerate(self.dataset):
- input_id, label, label_index = self.pad_input(batch)
-
- if input_ids is None:
- input_ids = input_id
- labels = label
- label_indices = [label_index]
- else:
- input_ids = torch.cat((input_ids, input_id), 0).to(self.device)
- labels = torch.cat((labels, label), 0).to(self.device)
- label_indices.append(label_index)
-
- if (idx + 1) % self.batch_size == 0:
- yield (input_ids, labels, label_indices)
- input_ids = None
- labels = None
- label_indices = None
- if (idx + 1) % self.batch_size != 0:
- yield (input_ids, labels, label_indices)
-
- def __len__(self):
- return self.length
-
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
parser.add_argument(
@@ -165,12 +73,6 @@ def parse_args():
default="wikitext-2-raw-v1",
help="The name of the pruning dataset to use (via the datasets library).",
)
- parser.add_argument(
- "--evaluation_dataset_name",
- type=str,
- default=None,
- help="The name of the evaluation dataset to use (via the datasets library).",
- )
parser.add_argument(
"--dataset_config_name",
type=str,
@@ -328,17 +230,18 @@ def parse_args():
"If passed, LLM loading time and RAM consumption will be benefited."
),
)
+
+ ### DDP mode config
+ parser.add_argument(
+ "--local_rank",
+ type=int, default=-1,
+ help="Automatic DDP Multi-GPU argument, do not modify")
+
# pruning config
parser.add_argument(
"--do_prune", action="store_true",
help="Whether or not to prune the model"
)
- parser.add_argument(
- "--max_pruning_steps",
- type=int,
- default=None,
- help="Total number of pruning steps to perform. If provided",
- )
parser.add_argument(
"--pruning_pattern",
type=str, default="1x1",
@@ -349,15 +252,6 @@ def parse_args():
type=float, default=0.5,
help="Target sparsity of the model."
)
- parser.add_argument(
- "--pruning_frequency",
- type=int, default=-1,
- help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps."
- )
- parser.add_argument(
- "--auto_slim", action="store_true",
- help="Whether or not to auto slim the model after pruning."
- )
parser.add_argument(
"--auto_config", action="store_true",
help="Whether to automatically generate pruning configs."
@@ -371,15 +265,16 @@ def parse_args():
"--trust_remote_code", default=True,
help="Transformers parameter: use the external repo")
- ### DDP mode config
- parser.add_argument(
- "--local_rank",
- type=int, default=-1,
- help="Automatic DDP Multi-GPU argument, do not modify")
-
- parser.add_argument("--eval_fp16", action='store_true',
- help=" fp16")
-
+ # Evaluation config
+ parser.add_argument("--tasks", default=["lambada_openai"],
+                        help="Evaluation tasks, usually chosen from ['lambada_openai','hellaswag','winogrande','piqa']",
+ )
+ parser.add_argument("--use_accelerate", action='store_true',
+                        help="Use the accelerate library for evaluation, typically needed for large models"
+ )
+ parser.add_argument("--eval_dtype", default='fp32',
+                        help="Evaluation dtype: choose from bf16, fp16 and fp32"
+ )
args = parser.parse_args()
@@ -484,14 +379,14 @@ def main():
logger.warning("You are instantiating a new config instance from scratch.")
is_llama = bool("llama" in args.model_name_or_path)
- is_t5 = bool("t5" in args.model_name_or_path)
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
elif args.model_name_or_path:
if is_llama:
tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name_or_path)
else :
- tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path,
+ use_fast=not args.use_slow_tokenizer, trust_remote_code=True)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -499,11 +394,9 @@ def main():
)
if args.model_name_or_path:
- if is_t5:
- model = T5ForConditionalGeneration.from_pretrained(
- args.model_name_or_path,
- config=config,
- )
+ if re.search("chatglm", args.model_name_or_path.lower()):
+ model = AutoModel.from_pretrained(args.model_name_or_path,
+ trust_remote_code=args.trust_remote_code) # .half()
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
@@ -621,7 +514,8 @@ def group_texts(examples):
pruning_configs=[
{
"pruning_type": "sparse_gpt",
- "op_names": [".attn", "_proj", ".fc", "key", "dense", "_h"],
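+            # Assumed intent of this change: ".*" lets the pruner match every supported op,
+            # while the output projections (lm_head / embed_out) are excluded so they stay dense.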
+ "op_names": [".*"],
+ "excluded_op_names": ["lm_head", "embed_out"],
}
]
else:
@@ -640,73 +534,60 @@ def group_texts(examples):
target_sparsity=args.target_sparsity,
pattern=args.pruning_pattern,
)
-
+
+ device = args.device
+ if device != 'cpu':
+ device = "cuda:"+str(device)
+
if args.do_prune:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
use_cache = model.config.use_cache
model.config.use_cache = False
- # if torch.cuda.is_available(): # Larger models(e.g. 80G+) may not load into the video card memory.
- # model = model.cuda()
- device = args.device
- if device != 'cpu':
- device = "cuda:"+str(device)
- from neural_compressor.training import prepare_pruning
- pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device)
+
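+        # prepare_pruning applies SparseGPT-style one-shot pruning layer by layer; the
+        # dataloader is used only for calibration forward passes (no gradient updates).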
+ pruning = prepare_pruning(model, configs, dataloader=train_dataloader, device=device)
model.config.use_cache = use_cache
if args.output_dir is not None:
###TODO set ddp save method
output_dir = args.output_dir
- if args.auto_slim:
- output_dir += "/before_slim"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info(f"The model has been exported to {output_dir}")
if device != 'cpu':
- model = model.to(device)
+ if not args.use_accelerate:
+ model = model.to(device)
+ else:
+ model = model.cpu()
logger.info(f"***** Evaluation in GPU mode. *****")
else:
logger.info(f"***** Evaluation in CPU mode. *****")
model.eval()
- if args.evaluation_dataset_name != None:
- dataset_eval = load_dataset(
- # for example:use the_pile's validation set for pruning, and lambada dataset for eval
- args.evaluation_dataset_name,
- args.dataset_config_name,
- split=f"validation",
- )
- else:
- dataset_eval = raw_datasets["validation"]
- dataset_eval = dataset_eval.shuffle(seed=42)
- evaluator = Evaluator(dataset_eval, tokenizer, model.device, batch_size=args.per_device_eval_batch_size)
- def eval_func(model):
- acc, avg_latency = evaluator.evaluate(model)
- return acc, avg_latency
- if not args.auto_slim:
- # only eval
- logger.info(f"***** Running Evaluation *****")
- acc, _ = eval_func(model)
- logger.info(f"pruned model accuracy:{acc}")
+ model_name = args.model_name_or_path
+ dtype = None
+ if args.eval_dtype == 'bf16':
+ model = model.to(dtype=torch.bfloat16)
+ dtype = 'bfloat16'
+ elif args.eval_dtype == 'fp16':
+ dtype = 'float16'
+ model = model.to(dtype=torch.float16)
else:
- logger.info(f"***** Running Evaluation before ffn auto slim*****")
- accuracy, avg_latency = eval_func(model)
- logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}")
- model = model_slim(model, round_multiplier=32)
-
- logger.info(f"***** Running Evaluation after ffn auto_slim*****")
- accuracy, avg_latency = eval_func(model)
- logger.info(f"accuracy:{accuracy} avg_latency:{avg_latency}")
-
- if RANK in {-1, 0}:
- if args.output_dir is not None and args.auto_slim:
- model.to('cpu')
- torch.save(model, args.output_dir+"/slimed_model.pt")
- tokenizer.save_pretrained(args.output_dir)
- if args.push_to_hub:
- repo.push_to_hub(commit_message="End of auto slim", auto_lfs_prune=True)
+ dtype = 'float32'
+ model = model.to(dtype=torch.float32)
+
+ model_args = f'pretrained={model_name},tokenizer={model_name},dtype={dtype},use_accelerate={args.use_accelerate},trust_remote_code={args.trust_remote_code}'
+ eval_batch = args.per_device_eval_batch_size
+ user_model = None if args.use_accelerate else model
+ results = evaluate(
+ model="hf-causal",
+ model_args=model_args,
+ user_model=user_model,
+ batch_size=eval_batch,
+ tasks=args.tasks,
+ device=device,
+ )
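+    # NOTE (assumption): with --use_accelerate the harness reloads the model from `model_args`
+    # (user_model is None); otherwise the in-memory pruned model above is evaluated directly.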
if __name__ == "__main__":
main()
diff --git a/neural_compressor/compression/pruner/pruning.py b/neural_compressor/compression/pruner/pruning.py
index e31847eec4d..1ded7e5f0c9 100644
--- a/neural_compressor/compression/pruner/pruning.py
+++ b/neural_compressor/compression/pruner/pruning.py
@@ -205,37 +205,41 @@ def _do_pruning(self):
layers = self._layers
self._model = self._model.cpu()
- inputs, inp_dict = collect_layer_inputs(
+ inputs, positional_inputs, other_input_infos = collect_layer_inputs(
model=self._model, layers=layers, layer_idx=0, layer_inputs=self._dataloader, device=self.dev
)
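+        # inputs: per-batch hidden states entering the first layer; positional_inputs: extra
+        # positional args captured from its forward call; other_input_infos: per-batch keyword
+        # args such as attention_mask (and alibi for bloom-style models).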
-
- with torch.no_grad():
- for i in tqdm(range(len(layers))):
- layer = layers[i].to(self.dev)
- layer_index_str = "." + str(i) + "."
- handles_list = []
- for pruner in self.pruners:
- layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key]
+ for i in tqdm(range(len(layers))):
+ layer = layers[i].to(self.dev)
+ layer_index_str = "." + str(i) + "."
+ handles_list = []
+ for pruner in self.pruners:
+ layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key]
+ if bool(layer_op_names):
handles_list.append(pruner.register_gpt_hook(layer_op_names))
+ prune_flag = bool(handles_list)
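+            # Run the calibration forward passes only when at least one pruner registered
+            # hooks on this layer; otherwise there is nothing for SparseGPT to collect here.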
+ if prune_flag:
for j in range(len(inputs)):
- input_infos = self.gather_single_batch_from_dict(inp_dict, j)
- layer(inputs[j], **input_infos)[0]
+ other_infos = self.gather_single_batch_from_dict(other_input_infos, j)
+ with torch.no_grad():
+ layer(inputs[j], *positional_inputs, **other_infos)[0]
for handles in handles_list:
for h in handles:
h.remove()
for pruner in self.pruners:
layer_op_names = [key for key in pruner.modules.keys() if layer_index_str in key]
pruner.fasterprune(layer_op_names)
- for j in range(len(inputs)):
- # the weights of current layer have been pruned, get the latest outputs as the inputs for next layer
- input_infos = self.gather_single_batch_from_dict(inp_dict, j)
- inputs[j] = layer(inputs[j], **input_infos)[0]
- layers[i] = layer.cpu()
- if "cuda" in self.dev.type:
- torch.cuda.empty_cache()
- del inp_dict
- del inputs
- gc.collect()
+ for j in range(len(inputs)):
+                # The weights of the current layer have been pruned; collect its latest outputs as the inputs for the next layer.
+ other_infos = self.gather_single_batch_from_dict(other_input_infos, j)
+ with torch.no_grad():
+ inputs[j] = layer(inputs[j], *positional_inputs, **other_infos)[0]
+ layers[i] = layer.cpu()
+ if "cuda" in self.dev.type:
+ torch.cuda.empty_cache()
+ del other_infos
+ del positional_inputs
+ del inputs
+ gc.collect()
if "cuda" in self.dev.type:
torch.cuda.empty_cache()
diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py
index bc43a0e02ab..d31bdb231b6 100644
--- a/neural_compressor/compression/pruner/utils.py
+++ b/neural_compressor/compression/pruner/utils.py
@@ -672,31 +672,17 @@ def unfoldLayer(module):
return layers
-def move_input_to_device(input, device):
- if device is None:
- device = torch.device("cpu")
- elif isinstance(device, str):
- device = torch.device(device)
-
+def move_input_to_device(input, device="cpu"):
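+    """Recursively move ``input`` (a tensor, or a dict/list/tuple of them) onto ``device``."""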
+ if isinstance(input, torch.Tensor):
+ return input.to(device)
if isinstance(input, dict) or isinstance(input, UserDict):
for inp in input.keys():
- input[inp] = input[inp].to(device) if isinstance(input[inp], torch.Tensor) else input[inp]
+ input[inp] = move_input_to_device(input[inp], device)
elif isinstance(input, list) or isinstance(input, tuple):
- input_res, prev_size = [], None
+ input_res = []
for inp in input:
- if prev_size:
- if isinstance(inp, torch.Tensor):
- if inp.size() == prev_size:
- input_res.append(inp.to(device))
- else:
- if torch.tensor(inp).size == prev_size:
- input_res.append(inp)
- else:
- input_res.append(inp.to(device) if isinstance(inp, torch.Tensor) else inp)
- prev_size = torch.tensor(inp).size()
+ input_res.append(move_input_to_device(inp, device))
input = input_res
- else:
- input = input.to(device) # pylint: disable=no-member
return input
@@ -712,25 +698,26 @@ def collect_layer_inputs(model, layers, layer_idx, layer_inputs, device="cuda:0"
Returns: input list.
"""
inputs = []
+ other_input_infos = {}
+ positional_inputs = []
model_dev = model.device
- attention_mask = None
- # 'alibi' is a necessary attribute for the bloom models
- inputs_info = {}
with torch.no_grad():
if layer_idx == 0:
layer = layers[layer_idx]
- def forward(self, hidden_states, **kwargs):
+ def forward(_, hidden_states, *positional_args, **kwargs):
+ nonlocal inputs
+ nonlocal positional_inputs
+ nonlocal other_input_infos
# TODO solve the problem of batchsize!=1
- inputs.append(hidden_states.to(device))
+ inputs.append(move_input_to_device(hidden_states, device))
+ if len(positional_inputs) <= 0:
+ positional_inputs = move_input_to_device(positional_args, device)
for key in kwargs.keys():
- if isinstance(kwargs[key], torch.Tensor) or (key == "alibi"):
- if key not in inputs_info.keys():
- inputs_info[key] = []
- if isinstance(kwargs[key], torch.Tensor):
- kwargs[key] = kwargs[key].to(device)
- inputs_info[key].append(kwargs[key])
+ if key not in other_input_infos.keys():
+ other_input_infos[key] = []
+ other_input_infos[key].append(move_input_to_device(kwargs[key], device))
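+                # Deliberately interrupt the forward pass once the first layer's inputs have
+                # been captured; the surrounding try/except below swallows this ValueError.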
raise ValueError
forward_cache = layers[layer_idx].forward
@@ -749,7 +736,6 @@ def forward(self, hidden_states, **kwargs):
except ValueError:
pass
layer.forward = forward_cache
-
else:
prev_layer = layers[layer_idx - 1]
@@ -758,4 +744,4 @@ def forward(self, hidden_states, **kwargs):
batch[0] = prev_output[0]
inputs.append(batch)
- return inputs, inputs_info
+ return inputs, positional_inputs, other_input_infos