
Jigsaw-ViT

PyTorch implementation of Jigsaw-ViT, accepted by Pattern Recognition Letters (2023):

Jigsaw-ViT: Learning Jigsaw Puzzles in Vision Transformer

by Yingyi Chen, Xi Shen, Yahui Liu, Qinghua Tao, Johan A.K. Suykens

[arXiv] [PDF] [Project Page]

If our project is helpful for your research, please consider citing:

@article{chen2022jigsaw,
  author={Chen, Yingyi and Shen, Xi and Liu, Yahui and Tao, Qinghua and Suykens, Johan A. K.},
  title={Jigsaw-ViT: Learning Jigsaw Puzzles in Vision Transformer},
  journal={Pattern Recognition Letters},
  volume={166},
  pages={53-60},
  year={2023},
  publisher={Elsevier}
}

Table of Contents

1. Installation
2. Model Zoo
3. Get Started on ImageNet
4. Get Started on Noisy Label Datasets
5. Get Started on Adversarial Examples
6. Acknowledgement

1. Installation

First, clone the repository locally:

git clone https://github.com/yingyichen-cyy/JigsawViT.git

Then, install PyTorch 1.7.1, torchvision 0.8.2, timm, einops, and munch:

conda create -n jigsaw-vit
conda activate jigsaw-vit
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
pip install timm==0.4.12
pip install einops
pip install munch
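
If the environment resolved correctly, a quick check like the following (a minimal sketch, not part of the repository) should run without errors and report the versions installed above:

# quick sanity check of the environment created above (illustrative, not part of the repo)
import torch, torchvision, timm, einops, munch  # all imports should succeed

print(torch.__version__)          # expected: 1.7.1
print(torchvision.__version__)    # expected: 0.8.2
print(timm.__version__)           # expected: 0.4.12
print(torch.cuda.is_available())  # True if cudatoolkit 11.0 is visible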

For experiments with the Swin backbone, please refer to here for setting up the environment.

2. Model Zoo

We release our jigsaw models with different backbones.

ImageNet-1K

Models with DeiT backbones are trained from scratch for 300 epochs. Models with Swin backbones are finetuned for 30 epochs based on the official ImageNet-22K pretrained checkpoints, as in https://github.com/microsoft/Swin-Transformer.

| backbone | resolution | #params | pretrain | Baseline acc@1 | Jigsaw-ViT acc@1 | Download |
| --- | --- | --- | --- | --- | --- | --- |
| DeiT-S/16 | 224x224 | 22M | - | 79.8 | 80.5 | model |
| DeiT-B/16 | 224x224 | 86M | - | 81.8 | 82.1 | model |
| Swin-S | 224x224 | 50M | Swin-S ImageNet-22K | 83.2 | 83.6 | model |
| Swin-B | 224x224 | 88M | Swin-B ImageNet-22K | 85.2 | 85.0 | model |

Clothing1M

Note that Nested Co-teaching (NCT) (Chen et al., 2022) is used for the experiments on the Clothing1M dataset. In the first stage of NCT, two independent models (base models) are trained from scratch on the Clothing1M training set. In the second stage, the two base models are finetuned jointly with Co-teaching.

| Model | backbone | role | acc@1 | Download |
| --- | --- | --- | --- | --- |
| Jigsaw-ViT | DeiT-S/16 | NCT base model 1 | 72.4 | model |
| Jigsaw-ViT | DeiT-S/16 | NCT base model 2 | 72.7 | model |
| Jigsaw-ViT+NCT | DeiT-S/16 | NCT | 75.4 | model |
| Jigsaw-ViT+NCT* | DeiT-S/16 | NCT | 75.6 | model |

* This is the current best Jigsaw-ViT+NCT model we have on Clothing1M.

3. Get Started on ImageNet

3.1 Data Preparation

Download and extract the ImageNet train and val images from http://image-net.org/. Please refer to here for extraction commands. The directory structure is the standard layout as in DeiT, with the training and validation data expected in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    n01440764/
    n01755581/
    n02012849/
    ...
  val/
    n01440764/
    n01755581/
    n02012849/
    ...
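
This is the standard torchvision ImageFolder layout, so a quick way to confirm the extraction worked is to load both splits. A minimal sketch, assuming the path /path/to/imagenet above; it is not part of the repository's code:

from torchvision import datasets, transforms

# minimal check that the extracted layout is readable by torchvision's ImageFolder
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder('/path/to/imagenet/train', transform=transform)
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)
print(len(train_set.classes), len(train_set))  # expect 1000 classes, ~1.28M training images
print(len(val_set))                            # expect 50000 validation images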

3.2 Training

Go to the ImageNet folder

cd imagenet/

3.2.1 DeiT as Backbone

To train Jigsaw-ViT with DeiT-Base/16 as backbone on a single node with 8 GPUs for 300 epochs, run:

cd jigsaw-deit/
# run Jigsaw-ViT with DeiT-Base/16 backbone
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_jigsaw.py --model jigsaw_base_patch16_224 --batch-size 128 --data-path ./imagenet --lambda-jigsaw 0.1 --mask-ratio 0.5 --output_dir ./jigsaw_base_results

Note that you can switch to the DeiT-Small/16 backbone by simply changing the model name:

# run Jigsaw-ViT with DeiT-Small/16 backbone
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_jigsaw.py --model jigsaw_small_patch16_224 --batch-size 128 --data-path ./imagenet --lambda-jigsaw 0.1 --mask-ratio 0.5 --output_dir ./jigsaw_small_results

Please refer to /imagenet/jigsaw-deit/bashes/ for more details.
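
For orientation, --lambda-jigsaw weights the auxiliary jigsaw (patch-position prediction) loss against the classification loss, and --mask-ratio sets the fraction of patch embeddings masked in the jigsaw branch. The sketch below only illustrates how the two losses are combined; the names and model interface are illustrative, not the repository's actual code:

import torch.nn.functional as F

lambda_jigsaw = 0.1  # corresponds to --lambda-jigsaw in the commands above

def total_loss(cls_logits, labels, pos_logits, pos_targets):
    # cls_logits: (B, num_classes) predictions from the class token
    # pos_logits: (N, num_positions) predictions of each shuffled patch's original
    #             position; pos_targets holds the ground-truth positions
    cls_loss = F.cross_entropy(cls_logits, labels)
    jigsaw_loss = F.cross_entropy(pos_logits, pos_targets)
    return cls_loss + lambda_jigsaw * jigsaw_loss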

3.2.2 Swin as Backbone

To finetune Jigsaw-ViT with the ImageNet-22K pretrained Swin-Base as backbone on a single node with 8 GPUs for 30 epochs, run:

cd jigsaw-swin
# fine-tune a Jigsaw-Swin-B model pre-trained on ImageNet-22K(21K):
bash ./bashes/run_jigsaw_swin_base.sh

To finetune Jigsaw-ViT with the ImageNet-22K pretrained Swin-Small as backbone on a single node with 8 GPUs for 30 epochs, run:

# fine-tune a Jigsaw-Swin-S model pre-trained on ImageNet-22K(21K):
bash ./bashes/run_jigsaw_swin_small.sh

Please refer to /imagenet/jigsaw-swin/bashes/ for more details.

3.3 Evaluation

For evaluation of the Jigsaw-ViT model, please first download our model from here.

cd /imagenet/jigsaw-deit
# please change --model, --data-path and --resume accordingly.
python main_jigsaw.py --eval --model jigsaw_base_patch16_224 --resume ./jigsaw_base_results/best_checkpoint.pth --data-path ./imagenet

For evaluation of the Jigsaw-Swin model, please first download our model from here.

cd /imagenet/jigsaw-swin
# please change --cfg, --data-path and --resume accordingly.
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main_jigsaw.py --eval --cfg configs/jigsaw_swin/jigsaw_swin_base_patch4_window7_224_22kto1k_finetune.yaml --resume ./jigsaw_swin_base_patch4_window7_224_22kto1k_finetune/ckpt_best.pth --data-path ./imagenet

4. Get Started on Noisy Label Datasets

4.1 Data Preparation

Please refer to the following links for downloading and preprocessing the datasets.

| Dataset | Download and preprocess |
| --- | --- |
| Animal-10N | /noisy-label/data/animal.sh |
| Food-101N | /noisy-label/data/food101n.sh |
| Clothing1M | /noisy-label/data/clothing.sh |

4.2 Training

Go to the noisy-label folder

cd noisy-label/

4.2.1 Training of Jigsaw-ViT

We train one Jigsaw-ViT for each dataset. The model can be trained on 2 NVIDIA Tesla P100-SXM2-16GB GPUs. Please refer to the following links for training details.

| Dataset | Training script |
| --- | --- |
| Animal-10N | /noisy-label/jigsaw-vit/bashes/Animal10N_eta2_mask02.sh |
| Food-101N | /noisy-label/jigsaw-vit/bashes/Food101N_eta1_mask05.sh |
| Clothing1M | /noisy-label/jigsaw-vit/bashes/Clothing1M_eta2_mask01.sh |

For example:
mkdir LabelNoise_checkpoint
# please change --data-path accordingly
bash bashes/Animal10N_eta2_mask02.sh

4.2.2 Training of Jigsaw-ViT+NCT

To finetune Jigsaw-ViT+NCT on Clothing1M, where NCT stands for Nested Co-teaching, first download the two trained Jigsaw-ViT base models from NCT base model 1 and NCT base model 2.

Then, we finetune Jigsaw-ViT+NCT on Clothing1M:

cd jigsaw-vit-NCT
mkdir finetune_nested
# please change --data-path and --resumePthList accordingly in run_finetune.sh
bash run_finetune.sh
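
For intuition, in the Co-teaching stage each network selects the small-loss (likely clean) samples in a mini-batch for the other network to learn from. The following is a simplified sketch of that selection step; the forget rate and function names are illustrative, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def coteaching_losses(logits1, logits2, labels, forget_rate=0.2):
    # keep the (1 - forget_rate) fraction of samples each network finds easiest
    keep = int((1.0 - forget_rate) * labels.size(0))
    loss1 = F.cross_entropy(logits1, labels, reduction='none')
    loss2 = F.cross_entropy(logits2, labels, reduction='none')
    idx1 = torch.argsort(loss1)[:keep]  # small-loss samples according to net 1
    idx2 = torch.argsort(loss2)[:keep]  # small-loss samples according to net 2
    # each network is updated on the samples selected by its peer
    update_loss1 = F.cross_entropy(logits1[idx2], labels[idx2])
    update_loss2 = F.cross_entropy(logits2[idx1], labels[idx1])
    return update_loss1, update_loss2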

4.3 Evaluation

For evaluation of the Jigsaw-ViT model, please first download our best model from here.

cd noisy-label/jigsaw-vit
# Please change "--data-path" and "--resumePth" to your own paths
python3 eval.py --data-path ../data/Clothing1M/ --data-set Clothing1M --arch vit_small_patch16 --resumePth ./pretrained/Clothing1M_eta1_iter400k_wd005_mask025_smallP16_aug_warm20k_Acc0.724 --gpu 0

For evaluation of the Jigsaw-ViT+NCT model, please first download our best model from here. Then please follow:

cd noisy-label/jigsaw-vit-NCT
# Please change "--data-path" and "--resumePth" to your own paths
python3 eval.py --data-path ../data/Clothing1M/ --data-set Clothing1M --arch vit_small_patch16 --resumePth ./finetune_nested/Clothing1M_eta1_iter50k_wd005_mask075_smallP16_aug_bs96_warm0_fgr0.2_lr5e-5_nested100_Acc0.756 --best-k 34 --gpu 0

5. Get Started on Adversarial Examples

We compare DeiT-S/16 and Jigsaw-ViT with the same backbone on adversarial examples. Both models are trained on ImageNet-1K from scratch and checkpoints are provided: Jigsaw-ViT, DeiT-S/16.

5.1 In Black-box Settings

Please go to adversarial-examples/bashes/run_transfer_attack.sh for black-box attacks. Here is one example:

cd adversarial-examples
# please change --data-path and --resumePth1, --resumePth2 accordingly.
# ViT-S/16 as surrogate model
python3 transfer_attack.py --arch vit --attack-type FGSM --data-path ./data/imagenet/ --resumePth1 ./pretrained/imagenet-deit_small_patch16_224-org-acc78.85.pth --resumePth2 ./pretrained/imagenet-deit_small_patch16_224-jigsaw-eta0.1-r0.5-acc80.51.pth --gpu 0
# ResNet-152 as surrogate model
python3 transfer_attack.py --arch resnet152 --attack-type FGSM --data-path ./data/imagenet/ --resumePth1 ./pretrained/imagenet-deit_small_patch16_224-org-acc78.85.pth --resumePth2 ./pretrained/imagenet-deit_small_patch16_224-jigsaw-eta0.1-r0.5-acc80.51.pth --gpu 0
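
Conceptually, a transfer attack crafts adversarial examples with the surrogate model's gradients and then measures how often they fool the (unseen) target model. Below is a minimal FGSM sketch of that procedure, assuming inputs normalized to [0, 1]; the model variables and default eps are placeholders, not the script's actual interface:

import torch
import torch.nn.functional as F

def fgsm_transfer(surrogate, target, images, labels, eps=0.125):
    # craft the perturbation on the surrogate model
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(surrogate(images), labels)
    loss.backward()
    adv = (images + eps * images.grad.sign()).clamp(0, 1).detach()
    # evaluate the crafted examples on the target model
    with torch.no_grad():
        preds = target(adv).argmax(dim=1)
    return (preds == labels).float().mean()  # robust accuracy of the target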

Note that the square attack, although a black-box attack, is implemented in white_attack.py only for convenience. Commands can be found in /adversarial-examples/bashes/run_square_attack.sh.

5.2 In White-box Settings

Please go to /adversarial-examples/bashes/run_white.sh for white-box attacks. Here is one example:

# please change --data-path and --resumePth accordingly.
python3 white_attack.py --attack-type FGSM --eps 0.125 --data-path ./data/imagenet/ --resumePth ./pretrained/imagenet-deit_small_patch16_224-org-acc78.85.pth --gpu 0
python3 white_attack.py --attack-type FGSM --eps 0.125 --data-path ./data/imagenet/ --resumePth ./pretrained/imagenet-deit_small_patch16_224-jigsaw-eta0.1-r0.5-acc80.51.pth --gpu 0

6. Acknowledgement

This repository is based on the official code of DeiT, Swin, MAE, Nested Co-teaching and Adversarial-Attacks-PyTorch.
