
PyTorch implementation of our paper Filter Sketch for Network Pruning, accepted by IEEE TNNLS, 2021.


lmbxmu/FilterSketch


Filter Sketch for Network Pruning (Link).

Pruning neural network models via filter sketch.
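Per the paper, filter sketch keeps the second-order information (the covariance W^T W) of each layer's pretrained filters by matrix sketching, using the Frequent Directions algorithm. The block below is a minimal NumPy sketch of Frequent Directions for illustration only; it is not the repository's code. Treating a layer's weights as a c_out x (c_in*k*k) matrix and keeping ell = sketch_rate * c_out rows is an assumption made here for the example.

```python
import numpy as np

def frequent_directions(A: np.ndarray, ell: int) -> np.ndarray:
    """Sketch the rows of A (n x d) into B (ell x d) so that B.T @ B approximates A.T @ A.

    Guarantees: A.T @ A - B.T @ B is positive semidefinite and its spectral
    norm is at most ||A||_F**2 / ell.
    """
    A = np.asarray(A, dtype=float)
    _, d = A.shape
    assert 0 < ell <= d, "this simple version assumes ell <= d"
    B = np.zeros((ell, d))
    next_zero = 0  # index of the first all-zero row of B
    for row in A:
        if next_zero == ell:
            # Sketch is full: subtract the smallest squared singular value from
            # all squared singular values, which zeroes out at least the last row.
            _, s, Vt = np.linalg.svd(B, full_matrices=False)
            s_shrunk = np.sqrt(np.maximum(s ** 2 - s[-1] ** 2, 0.0))
            B = s_shrunk[:, None] * Vt
            next_zero = int(np.count_nonzero(s_shrunk))
        B[next_zero] = row
        next_zero += 1
    return B

# Hypothetical layer: 16 filters of shape 3x3x3, flattened to a 16 x 27 matrix.
rng = np.random.default_rng(0)
W = rng.standard_normal((16, 27))
B = frequent_directions(W, ell=int(16 * 0.6))  # keep 9 sketched filters
print(B.shape)  # (9, 27)
```

The shrink step is what makes the sketch deterministic and error-bounded: every shrink removes the same amount from each retained direction, so no single direction of W^T W is over-penalized.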

Tips

If you have any problems, feel free to contact the authors by email: [email protected] or [email protected]. Please avoid posting issues on GitHub where possible, since I may not receive GitHub's notification emails and could miss the posted issues.

Citation

If you find FilterSketch useful in your research, please consider citing:

@article{lin2020filter,
  title={Filter Sketch for Network Pruning},
  author={Lin, Mingbao and Ji, Rongrong and Li, Shaojie and Ye, Qixiang and Tian, Yonghong and Liu, Jianzhuang and Tian, Qi},
  journal={arXiv preprint arXiv:2001.08514},
  year={2020}
}

Pre-trained Models

We provide the pre-trained models used in our paper.

CIFAR-10

| ResNet56 | ResNet110 | GoogLeNet |

ImageNet

| ResNet50 |

Result Models

We provide the pruned models from our experiments, along with their training logs and configurations.

| Model | Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Top-5 Accuracy | Download |
|---|---|---|---|---|---|---|---|
| ResNet56 | CIFAR-10 | [0.6]*27 | 73.36M (41.5%) | 0.50M (41.2%) | 93.19% | - | Link |
| ResNet110 | CIFAR-10 | [0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3 | 92.84M (63.3%) | 0.69M (59.9%) | 93.44% | - | Link |
| GoogLeNet | CIFAR-10 | [0.25]*9 | 0.59B (61.1%) | 2.61M (57.6%) | 94.88% | - | Link |
| ResNet50 | ImageNet | [0.2]*16 | 0.93B (77.3%) | 7.18M (71.8%) | 69.43% | 89.23% | Link |
| ResNet50 | ImageNet | [0.4]*16 | 1.51B (63.1%) | 10.40M (59.2%) | 73.04% | 91.18% | Link |
| ResNet50 | ImageNet | [0.6]*16 | 2.23B (45.5%) | 14.53M (43.0%) | 74.68% | 92.17% | Link |
| ResNet50 | ImageNet | [0.7]*16 | 2.64B (35.5%) | 16.95M (33.5%) | 75.22% | 92.41% | Link |
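Each prune rate in parentheses is the fraction of the unpruned model's FLOPs or parameters that was removed, so the unpruned baseline can be recovered as pruned / (1 - rate). The small helpers below are hypothetical and only illustrate this arithmetic on the ResNet56 row; because the table values are rounded, the recovered baselines are approximate.

```python
def prune_rate(pruned: float, baseline: float) -> float:
    """Fraction of the baseline removed by pruning, e.g. 0.415 for 41.5%."""
    return 1.0 - pruned / baseline

def implied_baseline(pruned: float, rate: float) -> float:
    """Recover the unpruned count from a pruned count and its prune rate."""
    return pruned / (1.0 - rate)

# ResNet56 on CIFAR-10 with sketch rate [0.6]*27 (values in millions)
flops_base = implied_baseline(73.36, 0.415)
params_base = implied_baseline(0.50, 0.412)
print(f"baseline FLOPs ~ {flops_base:.1f}M, params ~ {params_base:.2f}M")
# baseline FLOPs ~ 125.4M, params ~ 0.85M
```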

Performance of FilterSketch using ResNet-56 under different compression rates.

| Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Download |
|---|---|---|---|---|---|
| CIFAR-10 | [0.1]*27 | 11.43M (91.0%) | 0.08M (90.45%) | 87.38% | Link |
| CIFAR-10 | [0.2]*27 | 24.54M (80.6%) | 0.16M (81.0%) | 90.19% | Link |
| CIFAR-10 | [0.3]*27 | 35.61M (71.9%) | 0.25M (70.6%) | 91.65% | Link |
| CIFAR-10 | [0.4]*27 | 48.72M (61.5%) | 0.33M (61.1%) | 92.00% | Link |
| CIFAR-10 | [0.5]*27 | 63.78M (49.6%) | 0.43M (49.8%) | 92.29% | Link |
| CIFAR-10 | [0.6]*27 | 73.36M (41.5%) | 0.50M (41.2%) | 93.19% | Link |
| CIFAR-10 | [0.7]*27 | 87.31M (31.0%) | 0.59M (31.1%) | 93.36% | Link |
| CIFAR-10 | [0.8]*27 | 98.40M (22.3%) | 0.68M (20.8%) | 93.40% | Link |
| CIFAR-10 | [0.9]*27 | 111.5M (11.9%) | 0.75M (11.3%) | 93.44% | Link |
| CIFAR-10 | [0.9]*3+[0.1]*10+[0.1]*10+[0.6]*4 | 32.47M (74.4%) | 0.24M (71.8%) | 91.20% | Link |
| CIFAR-10 | [0.7]*3+[0.4]*10+[0.4]*10+[0.9]*4 | 62.63M (50.5%) | 0.48M (43.3%) | 92.94% | Link |
| CIFAR-10 | [0.8]*3+[0.5]*10+[0.8]*10+[0.9]*4 | 88.05M (30.4%) | 0.68M (20.6%) | 93.65% | Link |

Running Code

The code has been tested with PyTorch 1.3 and CUDA 10.0 on Ubuntu 16.04.

Filter Sketch

Run the following command to sketch a model on CIFAR-10:

python sketch_cifar.py \
--data_set cifar10 \
--data_path ../data/cifar10/ \
--sketch_model ./experiment/pretrain/resnet56.pt \
--job_dir ./experiment/resnet56/sketch/ \
--arch resnet \
--cfg resnet56 \
--lr 0.01 \
--lr_decay_step 50 100 \
--num_epochs 150 \
--gpus 0 \
--sketch_rate '[0.6]*27' \
--weight_norm_method l2
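The --lr_decay_step values appear to be the epochs at which the learning rate is decayed (50 and 100 for CIFAR-10, 30 and 60 for ImageNet). The decay factor is not stated in this README; the hypothetical helper below assumes a factor of 0.1, which is also the default gamma of PyTorch's torch.optim.lr_scheduler.MultiStepLR. It illustrates a multi-step schedule and is not code from the repository.

```python
def step_lr(base_lr: float, epoch: int, milestones: list, gamma: float = 0.1) -> float:
    """Learning rate at `epoch` under a multi-step schedule.

    gamma = 0.1 is an assumption; the README does not document the decay factor.
    """
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays

# CIFAR-10 settings from the command: --lr 0.01 --lr_decay_step 50 100 --num_epochs 150
for epoch in (0, 49, 50, 99, 100, 149):
    print(epoch, step_lr(0.01, epoch, [50, 100]))
```

In PyTorch the equivalent would be MultiStepLR(optimizer, milestones=[50, 100], gamma=0.1), stepped once per epoch.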

Run the following command to sketch a model on ImageNet:

python sketch_imagenet.py \
--data_set imagenet \
--data_path ../data/imagenet/ \
--sketch_model ./experiment/pretrain/resnet50.pth \
--job_dir ./experiment/resnet50/sketch/ \
--arch resnet \
--cfg resnet50 \
--lr 0.1 \
--lr_decay_step 30 60 \
--num_epochs 90 \
--gpus 0 \
--sketch_rate '[0.6]*16' \
--weight_norm_method l2

Test Our Performance

Run the command below to verify our pruned models:

python test.py \
--data_set cifar10 \
--data_path ../data/cifar10 \
--arch resnet \
--cfg resnet56 \
--sketch_model ./experiment/result/sketch_resnet56.pt \
--sketch_rate '[0.6]*27' \
--gpus 0
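The result tables report Top-1 and Top-5 accuracy. As a reminder of what these metrics measure, here is a minimal NumPy sketch of top-k accuracy computed from a batch of logits; it follows the standard definition and is not the code in test.py.

```python
import numpy as np

def topk_accuracy(logits: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    topk = np.argsort(-logits, axis=1)[:, :k]        # indices of the k largest logits per row
    hits = (topk == labels[:, None]).any(axis=1)     # True where the label is in the top k
    return 100.0 * hits.mean()

logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])
print(topk_accuracy(logits, labels, k=1), topk_accuracy(logits, labels, k=2))
```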

Get FLOPs and Params

To calculate the FLOPs and parameters of a model, first install the thop Python package, then run get_flops_params.py:

pip install thop
python get_flops_params.py \
--data_set cifar10 \
--input_image_size 32 \
--arch resnet \
--cfg resnet56 \
--sketch_rate '[0.6]*27'
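thop's profile function returns multiply-accumulate counts (MACs) together with parameter counts, and pruning papers commonly report MACs as FLOPs. The hypothetical helper below computes the textbook cost of one standard convolution to show why keeping fewer filters shrinks both numbers; thop's own counting rules (for bias, batch norm, activations, etc.) may differ, and the int(16 * 0.6) rounding is an assumption for the example.

```python
def conv2d_cost(c_in: int, c_out: int, k: int, h_out: int, w_out: int, bias: bool = False):
    """Return (MACs, params) of a standard (non-grouped) k x k 2-D convolution."""
    params = c_out * c_in * k * k + (c_out if bias else 0)
    macs = c_out * c_in * k * k * h_out * w_out  # one MAC per weight per output position
    return macs, params

# A 3x3, 16->16 convolution producing a 32x32 CIFAR-10 feature map
full = conv2d_cost(16, 16, 3, 32, 32)
# Hypothetical: keep 60% of the output filters, int(16 * 0.6) = 9
pruned = conv2d_cost(16, 9, 3, 32, 32)
print(full, pruned)  # (2359296, 2304) (1327104, 1296)
```

Removing output filters of one layer also removes input channels of the next layer, so whole-network savings compound across consecutive layers.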

Remarks

The number of sketch rates (entries in --sketch_rate) required for each network is as follows:

| Network | CIFAR-10 | ImageNet |
|---|---|---|
| ResNet56 | 27 | - |
| ResNet110 | 54 | - |
| GoogLeNet | 9 | - |
| ResNet50 | - | 16 |
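Sketch rates are written as Python list expressions such as [0.6]*27 or [0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3, and the expanded list must have the length given in the Remarks table. How the scripts parse this string is not documented here; the block below is a hypothetical helper that expands such strings with Python's ast module (instead of eval) and checks the count. The dictionary keys, in particular "resnet110" and "googlenet", are assumed cfg names.

```python
import ast

# Required number of sketch rates per network, from the Remarks table.
# Keys are hypothetical cfg names; only resnet56 and resnet50 appear in the commands.
REQUIRED_RATES = {"resnet56": 27, "resnet110": 54, "googlenet": 9, "resnet50": 16}

def parse_sketch_rate(expr: str) -> list:
    """Expand a string such as "[0.9]*3+[0.4]*24" into a flat list of floats.

    Only list literals, numbers, "+" and "*" are accepted, so arbitrary code
    is never executed (unlike a plain eval()).
    """
    def evaluate(node):
        if isinstance(node, ast.List):
            return [float(evaluate(elt)) for elt in node.elts]
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Mult)):
            left, right = evaluate(node.left), evaluate(node.right)
            return left + right if isinstance(node.op, ast.Add) else left * right
        raise ValueError(f"unsupported syntax in sketch rate: {expr!r}")

    rates = evaluate(ast.parse(expr, mode="eval").body)
    if not isinstance(rates, list):
        raise ValueError(f"sketch rate must evaluate to a list: {expr!r}")
    return rates

def check_sketch_rate(cfg: str, expr: str) -> list:
    """Parse expr and verify it has as many entries as cfg requires."""
    rates = parse_sketch_rate(expr)
    expected = REQUIRED_RATES[cfg]
    if len(rates) != expected:
        raise ValueError(f"{cfg} needs {expected} sketch rates, got {len(rates)}")
    return rates

print(len(check_sketch_rate("resnet110", "[0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3")))  # 54
```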

Other Arguments

optional arguments:
  -h, --help            show this help message and exit
  --gpus GPUS [GPUS ...]
                        Select the GPU ids to use. default:[0]
  --data_set DATA_SET   Select the dataset to train on. default:cifar10
  --data_path DATA_PATH
                        The directory where the input is stored.
                        default:/home/lishaojie/data/cifar10/
  --job_dir JOB_DIR     The directory where the summaries will be stored.
                        default:./experiments
  --arch ARCH           Architecture of the model. default:resnet
  --cfg CFG             Detailed architecture of the model. default:resnet56
  --num_epochs NUM_EPOCHS
                        The number of epochs to train. default:150
  --train_batch_size TRAIN_BATCH_SIZE
                        Batch size for training. default:128
  --eval_batch_size EVAL_BATCH_SIZE
                        Batch size for validation. default:100
  --momentum MOMENTUM   Momentum for MomentumOptimizer. default:0.9
  --lr LR               Learning rate for training. default:1e-2
  --lr_decay_step LR_DECAY_STEP [LR_DECAY_STEP ...]
                        The epochs at which the learning rate decays.
                        default:50, 100
  --weight_decay WEIGHT_DECAY
                        The weight decay of the loss. default:5e-4
  --start_conv START_CONV
                        The index of the conv layer at which sketching
                        starts; indices start from 0. default:1
  --sketch_rate SKETCH_RATE
                        The proportion of filters kept in each layer after
                        sketching. default:None
  --sketch_model SKETCH_MODEL
                        Path to the model to be sketched. default:None
  --weight_norm_method WEIGHT_NORM_METHOD
                        Select the weight norm method. default:None
                        Optional:l2
