This repository is the official PyTorch implementation of Global Vision Transformer Pruning with Hessian-Aware Saliency (also known as NViT), presented at CVPR 2023.
Huanrui Yang, Hongxu (Danny) Yin, Maying Shen, Pavlo Molchanov, Hai Li, Jan Kautz.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing
We propose NViT, a novel hardware-friendly global structural pruning algorithm enabled by a latency-aware, Hessian-based importance criterion and tailored to the ViT architecture. NViT achieves a nearly lossless 1.9x speedup and significantly outperforms SOTA ViT compression methods and efficient ViT designs.
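To make the idea of a Hessian-aware, latency-aware criterion concrete, here is a minimal pure-Python sketch of a generic second-order saliency for one structural group. It is not necessarily the exact NViT formulation. It assumes the Hessian is approximated by the empirical Fisher (the mean outer product of per-sample gradients). Under that assumption, and with the first-order term taken as roughly zero near convergence, zeroing a group with weights w raises the loss by about 0.5 * mean((g . w)^2). The latency-aware ranking divides saliency by a per-group latency saving. `group_saliency`, `rank_groups`, the group names, and all numbers are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def group_saliency(weights: Sequence[float], per_sample_grads: List[Sequence[float]]) -> float:
    """Estimated loss increase from zeroing one structural group.

    Second-order Taylor expansion with the Hessian replaced by the empirical
    Fisher, H ~ mean(g g^T). With delta_w = -w and the first-order term
    assumed ~0 near convergence, delta_L ~ 0.5 * mean((g . w)^2).
    """
    total = 0.0
    for grad in per_sample_grads:
        gw = sum(g * w for g, w in zip(grad, weights))  # g . w for one sample
        total += 0.5 * gw * gw
    return total / len(per_sample_grads)


def rank_groups(groups: Dict[str, Tuple[float, float]]) -> List[str]:
    """Globally rank groups of any type (heads, MLP neurons, ...).

    groups maps name -> (saliency, latency_saved). The ratio is an assumed
    latency-aware score: groups with the least saliency per unit of latency
    saved come first, i.e. they are the cheapest to prune.
    """
    return sorted(groups, key=lambda name: groups[name][0] / groups[name][1])


if __name__ == "__main__":
    w = [1.0, 2.0]
    grads = [[0.5, 0.0], [0.0, 0.5]]
    print(group_saliency(w, grads))  # 0.3125
    print(rank_groups({"head_3": (1.0, 1.0), "mlp_7": (1.0, 4.0), "head_9": (0.1, 1.0)}))
    # ['head_9', 'mlp_7', 'head_3']
```

Ranking every group on one shared scale, regardless of whether it is a head or an MLP neuron, is what makes the pruning "global" in this sketch. A per-layer ratio would instead fix how much each layer loses in advance.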
The code was tested in a virtual environment with Python 3.8. The code works best with PyTorch 1.7.0+ and torchvision 0.8.1+.
Besides PyTorch, pytorch-image-models (timm) 0.3.2 and einops are required. TensorboardX is used for logging.
pip install timm==0.3.2
pip install einops
pip install tensorboardX
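To confirm that an environment matches the versions above, you can read the installed package versions with the standard library's `importlib.metadata` (Python 3.8+). This is a minimal sketch, not part of the repository. `parse_version`, `satisfies`, and `check_environment` are hypothetical helpers. The version parsing is deliberately simplified: it keeps only the numeric major.minor.patch, drops local tags such as `+cu110`, and ignores pre-release suffixes.

```python
import re
from importlib import metadata
from typing import Dict, Tuple

# Requirements stated in this README: minimum versions for torch/torchvision,
# an exact pin for timm. einops and tensorboardX just need to be present.
REQUIREMENTS = {
    "torch": (">=", "1.7.0"),
    "torchvision": (">=", "0.8.1"),
    "timm": ("==", "0.3.2"),
    "einops": (None, None),
    "tensorboardX": (None, None),
}


def parse_version(text: str) -> Tuple[int, int, int]:
    """'1.7.1+cu110' -> (1, 7, 1). Simplified: non-numeric suffixes are ignored."""
    core = text.split("+", 1)[0]  # drop local version tag, e.g. +cu110
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))  # pad '0.8' to (0, 8, 0)
    return tuple(parts)


def satisfies(installed: str, op, required) -> bool:
    """Compare as integer tuples; op is '>=', '==', or None (any version)."""
    if op is None:
        return True
    have, want = parse_version(installed), parse_version(required)
    return have >= want if op == ">=" else have == want


def check_environment() -> Dict[str, str]:
    """Return a human-readable status string for each required package."""
    report = {}
    for name, (op, required) in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        ok = satisfies(installed, op, required)
        report[name] = f"{installed} ({'ok' if ok else 'needs ' + op + required})"
    return report


if __name__ == "__main__":
    for package, status in check_environment().items():
        print(f"{package}: {status}")
```

The comparison uses integer tuples rather than raw strings because, as strings, "1.13.1" sorts before "1.7.0".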
Please see TRAINING.md for detailed pruning and finetuning instructions of all models.
NViT model checkpoints after pruning and finetuning can be loaded and evaluated on ImageNet-1K using the following command:
python -m torch.distributed.launch --nproc_per_node=8 --use_env eval_nvit.py --finetune path/to/ft_checkpoint.pth --data-path /path/to/ImageNet2012/ --batch-size 256
The dimensions of each block are set automatically from the information stored in the checkpoint, so checkpoints with different pruning configurations can be evaluated in the same way.
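As an illustration of how per-block dimensions can be recovered from a checkpoint, the sketch below reads them from the shapes of each block's linear layers. PyTorch stores `nn.Linear` weights as `(out_features, in_features)`. The sketch assumes timm-style parameter names (`blocks.{i}.attn.qkv.weight`, `blocks.{i}.attn.proj.weight`, `blocks.{i}.mlp.fc1.weight`). NViT's pruned checkpoints may name or split these layers differently. `infer_block_dims` is a hypothetical helper, not the logic in `eval_nvit.py`, and the example shapes are made up.

```python
import re
from typing import Dict, List, Tuple

# Assumed timm-style names; adjust the pattern to the actual checkpoint keys.
BLOCK_KEY = re.compile(r"^blocks\.(\d+)\.(attn\.qkv|attn\.proj|mlp\.fc1)\.weight$")


def infer_block_dims(shapes: Dict[str, Tuple[int, ...]]) -> List[Dict[str, int]]:
    """Turn a {parameter name: shape} mapping into one dimension dict per block."""
    blocks: Dict[int, Dict[str, int]] = {}
    for name, shape in shapes.items():
        match = BLOCK_KEY.match(name)
        if match is None:
            continue  # patch embedding, head, norms, biases, ...
        index, layer = int(match.group(1)), match.group(2)
        out_features, in_features = shape  # nn.Linear weight layout
        dims = blocks.setdefault(index, {})
        if layer == "attn.qkv":
            dims["embed_dim"] = in_features
            dims["qkv_dim"] = out_features   # total q + k + v width
        elif layer == "attn.proj":
            dims["value_dim"] = in_features  # width of the concatenated heads
        else:  # mlp.fc1
            dims["mlp_hidden_dim"] = out_features
    return [blocks[i] for i in sorted(blocks)]


if __name__ == "__main__":
    example = {
        "blocks.0.attn.qkv.weight": (2304, 768),
        "blocks.0.attn.proj.weight": (768, 768),
        "blocks.0.mlp.fc1.weight": (3072, 768),
        # Hypothetical pruned block with narrower attention and MLP.
        "blocks.1.attn.qkv.weight": (1152, 768),
        "blocks.1.attn.proj.weight": (768, 384),
        "blocks.1.mlp.fc1.weight": (1536, 768),
        "head.weight": (1000, 768),
    }
    for i, dims in enumerate(infer_block_dims(example)):
        print(i, dims)
```

With a real checkpoint, the shape mapping could be built as `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["model"].items()}`. This assumes the weights are stored under a `"model"` key, as in DeiT-style checkpoints.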
Please see the NViT parameter redistribution repository to evaluate our insight on ViT parameter redistribution and to flexibly explore novel ViT designs with different dimensions in each block.
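When exploring designs with different dimensions in each block, it helps to see how each dimension contributes to model size. The sketch below counts the parameters of one ViT encoder block. It assumes the standard DeiT/timm layout: pre-norm with two LayerNorms, and biases on every linear layer. It also assumes separate per-head query/key and value widths. The parameter names and this parameterization are assumptions, not NViT's exact definition. For an unpruned DeiT-B block (embedding 768, 12 heads of width 64, MLP 3072), it gives 7,087,872 parameters.

```python
def vit_block_params(embed_dim: int, num_heads: int, qk_dim: int, v_dim: int, mlp_dim: int) -> int:
    """Parameter count of one pre-norm ViT block with per-head widths.

    qk_dim and v_dim are per-head widths; all linear layers carry biases.
    """
    qk = num_heads * qk_dim                  # total query (and key) width
    v = num_heads * v_dim                    # total value width
    attn = 2 * (embed_dim * qk + qk)         # query and key projections
    attn += embed_dim * v + v                # value projection
    attn += v * embed_dim + embed_dim        # output projection
    mlp = embed_dim * mlp_dim + mlp_dim      # fc1
    mlp += mlp_dim * embed_dim + embed_dim   # fc2
    norms = 2 * 2 * embed_dim                # two LayerNorms, weight + bias each
    return attn + mlp + norms


if __name__ == "__main__":
    print(vit_block_params(768, 12, 64, 64, 3072))  # 7087872, a DeiT-B block
    # Hypothetical pruned block: fewer heads, narrower q/k, smaller MLP.
    print(vit_block_params(768, 8, 32, 64, 1536))   # 3546112
```

The MLP accounts for about two thirds of a DeiT-B block, so shrinking `mlp_dim` reduces parameters fastest. Because of the residual connections, `embed_dim` has to be the same in every block.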
NViT ImageNet-1K Pruned Models
Models pruned from the pretrained DeiT-B model. Speedup is computed relative to DeiT-B, measured on a single V100 GPU with batch size 256.
Name | Acc@1(%) | Speedup(x) | Resolution | #Params(M) | FLOPs(G) | Download |
---|---|---|---|---|---|---|
NViT-B | 83.29 | 1.86 | 224x224 | 34 | 6.8 | model |
NViT-H | 82.95 | 2.01 | 224x224 | 30 | 6.2 | model |
NViT-S | 82.19 | 2.52 | 224x224 | 21 | 4.2 | model |
NViT-T | 76.21 | 4.97 | 224x224 | 6.9 | 1.3 | model |
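The speedups above are the authors' measurements. The benchmarking script itself is not described here, so the following is only a sketch of the general method. It assumes speedup is the ratio of inference throughput between the pruned model and DeiT-B at the same batch size. `measure_throughput` and `speedup` are hypothetical helpers. For a GPU model, `run_batch` should call `torch.cuda.synchronize()` before returning. Otherwise the timer measures asynchronous kernel launches rather than finished work.

```python
import time
from typing import Callable


def measure_throughput(run_batch: Callable[[], None], batch_size: int,
                       warmup: int = 10, iters: int = 50) -> float:
    """Images per second for a callable that runs one forward pass on one batch."""
    for _ in range(warmup):  # exclude one-time costs (allocation, autotuning)
        run_batch()
    start = time.perf_counter()
    for _ in range(iters):
        run_batch()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


def speedup(baseline_throughput: float, pruned_throughput: float) -> float:
    """Throughput ratio; equals the latency ratio when the batch size is fixed."""
    return pruned_throughput / baseline_throughput


if __name__ == "__main__":
    # Stand-ins for real models: sleeps of different lengths per batch.
    base = measure_throughput(lambda: time.sleep(0.004), batch_size=256, warmup=1, iters=5)
    fast = measure_throughput(lambda: time.sleep(0.002), batch_size=256, warmup=1, iters=5)
    print(f"speedup ~ {speedup(base, fast):.2f}x")
```

Warmup iterations matter on GPUs. The first forward passes include memory allocation and kernel selection, which would otherwise deflate the baseline and pruned numbers unevenly.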
Please consider citing NViT if this repository is useful for your work.
@InProceedings{Yang_2023_CVPR,
author = {Yang, Huanrui and Yin, Hongxu and Shen, Maying and Molchanov, Pavlo and Li, Hai and Kautz, Jan},
title = {Global Vision Transformer Pruning With Hessian-Aware Saliency},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {18547-18557}
}
Copyright © 2023, NVIDIA Corporation. All rights reserved.
This work is made available under the NVIDIA Source Code License-NC. Click here to view a copy of this license.
The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
For license information regarding the timm repository, please refer to its repository.
For license information regarding the DeiT repository, please refer to its repository.
For license information regarding the ImageNet dataset, please see the ImageNet official website.
This repository is built on top of the timm repository. We thank Ross Wightman for creating and maintaining this high-quality library.
Part of this code is modified from the official repo of DeiT. We thank the authors for their amazing work and for releasing their code base.