Step-by-Step

This document provides step-by-step instructions for reproducing the pruning of Hugging Face question-answering models.

Prerequisite

Environment

# install dependencies
cd examples/pytorch/nlp/huggingface_models/question-answering/pruning/eager
pip install -r requirements.txt
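After installing, you can optionally confirm that the core dependencies resolved correctly. The short check below is only a sketch; it assumes torch and transformers are among the packages pinned in requirements.txt.

# optional sanity check (illustrative only)
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())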

Pruning

1. Train Sparse Model

Train a sparse model with the N:M (2:4) pattern on SQuAD:

python3 ./run_qa_no_trainer.py \
        --model_name_or_path "/path/to/bertmini/dense_finetuned_model" \
        --dataset_name "squad" \
        --max_seq_length 384 \
        --doc_stride 128 \
        --per_device_train_batch_size 16 \
        --per_device_eval_batch_size 16 \
        --num_warmup_steps 1000 \
        --output_dir "./sparse_qa_bertmini/" \
        --do_prune \
        --target_sparsity 0.5 \
        --pruning_pattern "2:4" \
        --pruning_frequency 1000 \
        --cooldown_epochs 5 \
        --learning_rate 4.5e-4 \
        --num_train_epochs 10 \
        --weight_decay  1e-7 \
        --distill_loss_weight 4.5
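The 2:4 pattern allows at most 2 non-zero weights in every group of 4 consecutive elements, which is why --target_sparsity is 0.5 here. The snippet below is a minimal illustration of such a mask using a plain magnitude criterion; the actual example scores weights with the snip-momentum method described under Results.

import torch

def two_four_mask(weight: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest-magnitude weights in every group of 4 consecutive
    # elements of the flattened tensor (illustrative criterion only).
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(groups)
    mask.scatter_(1, keep, 1.0)          # 1 = kept, 0 = pruned
    return mask.reshape(weight.shape)

w = torch.randn(8, 8)
print((two_four_mask(w) == 0).float().mean())   # ~0.5 element-wise sparsity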

To use the NxM (4x1) pruning pattern instead:

python ./run_qa_no_trainer.py \
        --model_name_or_path "/path/to/bertmini/dense_finetuned_model" \
        --dataset_name "squad" \
        --max_seq_length 384 \
        --doc_stride 128 \
        --per_device_train_batch_size 16 \
        --per_device_eval_batch_size 16 \
        --num_warmup_steps 1000 \
        --output_dir "./sparse_qa_bertmini" \
        --do_prune \
        --target_sparsity 0.8 \
        --pruning_pattern "4x1" \
        --pruning_frequency 1000 \
        --cooldown_epochs 5 \
        --learning_rate 4.5e-4 \
        --num_train_epochs 10 \
        --weight_decay  1e-7 \
        --distill_loss_weight 4.5
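With the 4x1 pattern, weights are removed in blocks of 4 consecutive elements, so a whole block is either kept or zeroed, and the number of pruned blocks is driven by --target_sparsity. A rough sketch of the idea, again with a simple L1 block score standing in for the real criterion:

import torch

def four_by_one_mask(weight: torch.Tensor, target_sparsity: float = 0.8) -> torch.Tensor:
    # Score each 4x1 block by its L1 norm and zero the lowest-scoring blocks
    # until roughly `target_sparsity` of the elements are pruned.
    blocks = weight.reshape(-1, 4)
    scores = blocks.abs().sum(dim=1)
    n_prune = int(target_sparsity * scores.numel())
    mask = torch.ones_like(blocks)
    mask[scores.argsort()[:n_prune]] = 0.0
    return mask.reshape(weight.shape)

w = torch.randn(16, 16)
print((four_by_one_mask(w) == 0).float().mean())   # ~0.8 element-wise sparsity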

2. Examples

The pruning results for the distilbert-base-uncased, bert-base-uncased, and bert-large models can be reproduced with the following commands:

python run_qa_no_trainer.py \
        --model_name_or_path "/path/to/distilbert-base-uncased/dense_finetuned_model" \
        --dataset_name "squad" \
        --max_seq_length 384 \
        --doc_stride 128 \
        --per_device_train_batch_size 16 \
        --per_device_eval_batch_size 16 \
        --do_prune \
        --num_warmup_steps 1000 \
        --output_dir "./sparse_qa_distilbert" \
        --weight_decay 1e-7 \
        --learning_rate 1e-4 \
        --cooldown_epochs 10 \
        --num_train_epochs 20 \
        --distill_loss_weight 3 \
        --target_sparsity 0.8 \
        --pruning_pattern "4x1" \
        --pruning_frequency 1000

python run_qa_no_trainer.py \
        --model_name_or_path "/path/to/bert-base-uncased/dense_finetuned_model/" \
        --dataset_name squad \
        --max_seq_length 384 \
        --doc_stride 128 \
        --per_device_train_batch_size 12 \
        --do_prune \
        --num_warmup_steps 1000 \
        --output_dir "./sparse_qa_bertbase" \
        --weight_decay 1e-7 \
        --learning_rate 7e-5 \
        --cooldown_epochs 4 \
        --num_train_epochs 10 \
        --distill_loss_weight 4.5 \
        --target_sparsity 0.8 \
        --pruning_pattern "4x1" \
        --pruning_frequency 1000

python run_qa_no_trainer.py \
        --model_name_or_path "/path/to/bert-large/dense_finetuned_model/" \
        --dataset_name "squad" \
        --max_seq_length 384 \
        --doc_stride 128 \
        --per_device_train_batch_size 24 \
        --per_device_eval_batch_size 24 \
        --do_prune \
        --num_warmup_steps 1000 \
        --output_dir "./sparse_qa_bertlarge" \
        --weight_decay 0 \
        --learning_rate 5e-5 \
        --checkpointing_steps "epoch" \
        --cooldown_epochs 10 \
        --num_train_epochs 40 \
        --distill_loss_weight 3 \
        --target_sparsity 0.8 \
        --pruning_pattern "4x1" \
        --pruning_frequency 1000
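All of the commands above also pass --distill_loss_weight, which weights a distillation term on top of the regular question-answering loss so that the sparse student stays close to its dense starting point. The sketch below is only an assumption about the general shape of such a loss (KL divergence on the start/end logits against the dense teacher); the exact formulation inside run_qa_no_trainer.py may differ.

import torch.nn.functional as F

def qa_loss_with_distillation(student_start, student_end,
                              teacher_start, teacher_end,
                              start_positions, end_positions,
                              distill_loss_weight=4.5):
    # Standard extractive-QA cross-entropy on the student's span logits.
    task_loss = (F.cross_entropy(student_start, start_positions)
                 + F.cross_entropy(student_end, end_positions)) / 2
    # Distillation term: match the dense teacher's start/end distributions.
    distill_loss = (F.kl_div(F.log_softmax(student_start, dim=-1),
                             F.softmax(teacher_start, dim=-1), reduction="batchmean")
                    + F.kl_div(F.log_softmax(student_end, dim=-1),
                               F.softmax(teacher_end, dim=-1), reduction="batchmean")) / 2
    return task_loss + distill_loss_weight * distill_loss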

Pruning with 2:4 sparsity is similar to the above; only target_sparsity and pruning_pattern need to be changed. Please refer to Pruning Scripts for details.

Dense model fine-tuning is also supported, as shown below (simply do not pass --do_prune):

python ./run_qa_no_trainer.py \
    --model_name_or_path "prajjwal1/bert-mini" \
    --dataset_name "squad" \
    --max_seq_length 384 \
    --doc_stride 128 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_warmup_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 5 \
    --output_dir "./dense_qa_bertmini"

Please refer to Bert-mini SQuAD for details.

Results

The snip-momentum pruning method is used by default and the initial dense models are all fine-tuned.
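For reference, the scoring idea behind snip-momentum can be sketched as the SNIP saliency |weight x gradient| accumulated with a momentum term; the constants and the exact update used by the library are assumptions here.

import torch

class SnipMomentumScore:
    # Tracks a per-weight importance score; the lowest-scoring weights or
    # blocks are the ones removed at each pruning step (illustrative only).
    def __init__(self, weight: torch.Tensor, momentum: float = 0.9):
        self.momentum = momentum
        self.score = torch.zeros_like(weight)

    def update(self, weight: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        saliency = (weight * grad).abs()              # SNIP-style saliency
        self.score = self.momentum * self.score + saliency
        return self.score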

| Model | Dataset | Sparsity pattern | Sparsity ratio (element-wise over MatMul, Gemm, Conv) | Dense F1 (mean/max) | Sparse F1 (mean/max) | Relative drop |
|---|---|---|---|---|---|---|
| Bert-mini | SQuAD | 4x1 | 0.7993 | 0.7662/0.7687 | 0.7617/0.7627 | -0.78% |
| Bert-mini | SQuAD | 2:4 | 0.4795 | 0.7662/0.7687 | 0.7733/0.7762 | +0.98% |
| Distilbert-base-uncased | SQuAD | 4x1 | 0.7986 | 0.8690 | 0.8615 | -0.86% |
| Distilbert-base-uncased | SQuAD | 2:4 | 0.5000 | 0.8690 | 0.8731/0.8750 | +0.69% |
| Bert-base-uncased | SQuAD | 4x1 | 0.7986 | 0.8859 | 0.8778 | -0.92% |
| Bert-base-uncased | SQuAD | 2:4 | 0.5000 | 0.8859 | 0.8924/0.8940 | +0.91% |
| Bert-large | SQuAD | 4x1 | 0.7988 | 0.9123 | 0.9091 | -0.35% |
| Bert-large | SQuAD | 2:4 | 0.5002 | 0.9123 | 0.9167 | +0.48% |
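The relative drop appears to follow (sparse F1 - dense F1) / dense F1, taking the max value where both a mean and a max are reported. A quick check against the Bert-mini 4x1 row:

# relative drop = (sparse F1 - dense F1) / dense F1, e.g. Bert-mini 4x1
dense_f1, sparse_f1 = 0.7687, 0.7627        # max values from the table above
print(f"{(sparse_f1 - dense_f1) / dense_f1:+.2%}")   # -> -0.78%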
