
Gradient-Estimation Test-time Augmentation (Unofficial)


agil27/GrETTA


GrETTA: Gradient-Estimation Test-Time data Augmentation

Yuanbiao Wang, Jiancheng Yang, Zudi Lin

If you have any questions about this project, contact me via [email protected]

Usage

Environment

Install the latest versions of the following software and packages:

Python3
PyTorch
torchvision
kornia
efficientnet_pytorch
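Before running anything, it can help to confirm that these packages are importable. The following is a minimal sketch that assumes the usual import names (`torch` for PyTorch, then `torchvision`, `kornia` and `efficientnet_pytorch`). The helper `missing_packages` is hypothetical and not part of this repository.

```python
import importlib.util

# Import names assumed for the packages listed above (PyTorch installs as "torch").
REQUIRED_MODULES = ["torch", "torchvision", "kornia", "efficientnet_pytorch"]


def missing_packages(modules):
    """Return the module names that cannot be found by the import system."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

`find_spec` only locates each module without importing it, so the check stays fast even for large packages like PyTorch.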

Pretrained weights

Our experiments require ResNeXt and WideResNet models pre-trained on CIFAR-100.

We provide weights for these large convolutional networks, as well as smaller distilled student models, which are useful for gradient estimation. You can download them via Google Drive.
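Checkpoints saved during training often wrap the weights in a larger dictionary and, if the model was wrapped in `nn.DataParallel`, prefix every key with `module.`. The sketch below normalizes such a checkpoint into a plain state dict. It assumes (not confirmed by this README) that the downloaded files are either bare state dicts or dictionaries storing the weights under a `"state_dict"` key. The function name `extract_state_dict` is hypothetical.

```python
from collections import OrderedDict


def extract_state_dict(checkpoint):
    """Return a plain state dict from an already-loaded checkpoint object.

    Assumption: the checkpoint is either a bare state dict or a dict that
    stores it under "state_dict", and its keys may carry the "module."
    prefix that nn.DataParallel adds when a wrapped model is saved.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    cleaned = OrderedDict()
    for key, value in checkpoint.items():
        # Strip the DataParallel prefix so keys match an unwrapped model.
        if key.startswith("module."):
            key = key[len("module."):]
        cleaned[key] = value
    return cleaned
```

With PyTorch installed, usage would look like `model.load_state_dict(extract_state_dict(torch.load("model/model_resnext_best.pth.tar", map_location="cpu")))`, where `model` is a network built with the matching architecture.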

Dataset

We use CIFAR-100 and its corrupted counterpart, CIFAR-100-C, to evaluate robustness. You can download CIFAR-100-C via [Google Drive].

The original CIFAR-100 dataset is handled automatically by torchvision.
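A minimal sketch for reading one corruption at one severity, assuming the standard CIFAR-100-C release layout: one `<corruption>.npy` file per corruption type holding 50,000 images (five severity levels of the 10,000 CIFAR-100 test images, stacked in order from severity 1 to 5) plus a shared `labels.npy`. The corruption list, `severity_slice` and `load_corruption` are illustrative assumptions, not code from this repository.

```python
import os

import numpy as np

# The 15 corruption types commonly reported on CIFAR-100-C (an assumption;
# the released archive also contains a few extra corruption files).
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]

IMAGES_PER_SEVERITY = 10_000  # size of the CIFAR-100 test set


def severity_slice(severity, per_severity=IMAGES_PER_SEVERITY):
    """Index range of one severity level (1-5) inside a corruption array."""
    if not 1 <= severity <= 5:
        raise ValueError("severity must be between 1 and 5")
    start = (severity - 1) * per_severity
    return slice(start, start + per_severity)


def load_corruption(root, name, severity, per_severity=IMAGES_PER_SEVERITY):
    """Load images (N, 32, 32, 3) and labels (N,) for one corruption and severity."""
    images = np.load(os.path.join(root, f"{name}.npy"))
    labels = np.load(os.path.join(root, "labels.npy"))
    s = severity_slice(severity, per_severity)
    return images[s], labels[s]
```

For example, `load_corruption("cifar_corrupt/CIFAR-100-C", "fog", severity=3)` would return the third fog severity. The clean test set can be obtained with `torchvision.datasets.CIFAR100(root="cifar", train=False, download=True)`.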

Training

Set the arguments in the function default_arg in utils.py.

The default arguments are:

{
    'model': 'resnext',
    'augmix': True,
    'opt': 'adam',
    'lr': 1e-6,
    'init': 'zero',
    'epochs': 2,
    'reg': 'none',
    'policy': 'resnet18',
    'sigmoid': False,
    'est': 'vanilla',
    'transform': 'geometry',
    'num_samples': 12,
    'paths': {
        'wideresnet': 'model/model_wrn_best.pth.tar',
        'resnext': 'model/model_resnext_best.pth.tar'
    },
    'student_paths': {
        'wideresnet': 'model/wrn_student_augmix.pth',
        'resnext': 'model/resnext_student_augmix.pth'
    },
    'checkpoints': 'checkpoints',
    'cifar': 'cifar',
    'cifarc': 'cifar_corrupt/CIFAR-100-C'
}
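If you prefer not to edit the defaults by hand, one option is to apply overrides to a copy of them. This is a minimal sketch, assuming `default_arg()` takes no arguments and returns the dictionary above; `merge_args` is a hypothetical helper, not part of utils.py.

```python
import copy


def merge_args(defaults, overrides):
    """Return a copy of `defaults` with `overrides` applied (nested dicts merged).

    Unknown keys raise KeyError so that typos in an override are caught early.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError(f"unknown argument: {key!r}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            # Merge nested sections such as "paths" key by key.
            merged[key] = merge_args(merged[key], value)
        else:
            merged[key] = value
    return merged
```

For example, `merge_args(default_arg(), {"model": "wideresnet", "epochs": 5})` would switch to the WideResNet weights and train for five epochs while leaving every other default untouched. Deep-copying keeps the original defaults intact between runs.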

From the repository's root directory, run:

python3 train.py
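The paths in the default arguments are relative, so they resolve against the directory you launch from. A quick pre-flight check can list any weight or data paths that are not in place yet. This is a minimal sketch assuming the argument structure shown above; whether train.py reads every one of these paths depends on the chosen options, and `missing_paths` is a hypothetical helper.

```python
import os


def missing_paths(args):
    """List configured weight/data paths that do not exist yet.

    Assumes `args` has the structure of the defaults above: teacher and
    student weight files under "paths"/"student_paths" for the selected
    "model", plus the CIFAR-100-C directory under "cifarc". The clean
    CIFAR-100 directory ("cifar") is skipped because torchvision handles it.
    """
    model = args["model"]
    required = [
        args["paths"][model],
        args["student_paths"][model],
        args["cifarc"],
    ]
    return [path for path in required if not os.path.exists(path)]
```

Running this on the default arguments from the repository root should return an empty list once the downloaded weights and CIFAR-100-C are in place.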

Results

  1. Experiments on Modified MNIST

    Visualization

  2. Experiments on the CIFAR-100-C robustness benchmark

    ResNeXt as the classification model

    WideResNet as the classification model

    Visualization

  3. ChestXRay100 to CheXpert transfer learning test
