GrETTA: Gradient-Estimation Test-Time data Augmentation

Yuanbiao Wang, Jiancheng Yang, Zudi Lin

If you have any questions about this project, contact me via [email protected]

Usage

Environment

Please install the latest versions of the following software and packages:

Python3
PyTorch
torchvision
kornia
efficientnet_pytorch
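Before training, a short check like the one below can confirm that these packages are importable. It is a minimal sketch using only the standard library. The module names are the usual import names for these packages (PyTorch is imported as torch), and the helper missing_packages is hypothetical, not part of this repository.

```python
import importlib.util

# Usual import names for the packages listed above (PyTorch is imported as "torch").
REQUIRED = ["torch", "torchvision", "kornia", "efficientnet_pytorch"]


def missing_packages(names=REQUIRED):
    """Return the top-level module names that cannot be found on the current path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```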

Pretrained weights

Our experiments require ResNeXt and WideResNet models pre-trained on CIFAR-100.

We provide the weights for these large ConvNets, as well as smaller distilled student models, which are useful for gradient estimation. You can download them via Google Drive.
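After downloading, the weight files need to sit at the locations that the default configuration points to (the paths and student_paths entries of default_arg in utils.py, listed under Training). The sketch below checks this before a run. It copies those default paths; the helper missing_weights is hypothetical, so adjust the paths if you changed them in default_arg.

```python
from pathlib import Path

# Default locations copied from the 'paths' and 'student_paths' entries of default_arg.
TEACHER_PATHS = {
    "wideresnet": "model/model_wrn_best.pth.tar",
    "resnext": "model/model_resnext_best.pth.tar",
}
STUDENT_PATHS = {
    "wideresnet": "model/wrn_student_augmix.pth",
    "resnext": "model/resnext_student_augmix.pth",
}


def missing_weights(model, root="."):
    """List the teacher and student weight files for `model` that are absent under `root`."""
    if model not in TEACHER_PATHS:
        raise ValueError(f"unknown model {model!r}; expected one of {sorted(TEACHER_PATHS)}")
    wanted = [TEACHER_PATHS[model], STUDENT_PATHS[model]]
    return [p for p in wanted if not (Path(root) / p).is_file()]


if __name__ == "__main__":
    for name in TEACHER_PATHS:
        missing = missing_weights(name)
        print(name, "OK" if not missing else f"missing: {missing}")
```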

Dataset

We use CIFAR-100 and its corrupted version, CIFAR-100-C, to test robustness. You can download the CIFAR-100-C dataset via [Google Drive].

The original CIFAR-100 dataset should be downloaded automatically by torchvision.
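CIFAR-100-C is distributed as NumPy arrays. Assuming the standard layout of the public release (one <corruption>.npy file per corruption holding the 10,000 test images at severities 1 to 5, stacked in that order, plus a shared labels.npy), the sketch below locates the files for one corruption and selects a single severity level. The helper names are hypothetical, and the repository's own loader may organize this differently; the root would typically be the cifarc path from the configuration.

```python
from pathlib import Path

# Assumes the standard CIFAR-100-C release layout: each <corruption>.npy holds
# 50,000 images (10,000 test images x severities 1..5, stacked in order), and
# labels.npy holds the 50,000 matching labels.
IMAGES_PER_SEVERITY = 10_000


def severity_slice(severity):
    """Index range of one severity level (1-5) inside a corruption array."""
    if not 1 <= severity <= 5:
        raise ValueError("severity must be between 1 and 5")
    start = (severity - 1) * IMAGES_PER_SEVERITY
    return slice(start, start + IMAGES_PER_SEVERITY)


def corruption_files(root, corruption):
    """Paths of the image array and the shared label array for one corruption."""
    root = Path(root)
    return root / f"{corruption}.npy", root / "labels.npy"
```

With NumPy, one severity level can then be read as np.load(images_path)[severity_slice(3)], applying the same slice to the labels array so images and labels stay aligned.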

Training

Set the arguments in the default_arg function in utils.py.

The default arguments are:

{
    'model': 'resnext',
    'augmix': True,
    'opt': 'adam',
    'lr': 1e-6,
    'init': 'zero',
    'epochs': 2,
    'reg': 'none',
    'policy': 'resnet18',
    'sigmoid': False,
    'est': 'vanilla',
    'transform': 'geometry',
    'num_samples': 12,
    'paths': {
        'wideresnet': 'model/model_wrn_best.pth.tar',
        'resnext': 'model/model_resnext_best.pth.tar'
    },
    'student_paths': {
        'wideresnet': 'model/wrn_student_augmix.pth',
        'resnext': 'model/resnext_student_augmix.pth'
    },
    'checkpoints': 'checkpoints',
    'cifar': 'cifar',
    'cifarc': 'cifar_corrupt/CIFAR-100-C'
}
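Instead of editing the dictionary above in place, it can be convenient to start from the defaults and override only a few entries, rejecting misspelled keys. The sketch below is a hypothetical helper, not part of the repository; it assumes default_arg() returns a plain dict shaped like the one above.

```python
import copy


def with_overrides(defaults, **overrides):
    """Return a deep copy of `defaults` with selected keys replaced.

    Nested dicts (e.g. 'paths') are merged key by key; unknown top-level keys
    raise KeyError so that a typo such as 'num_sample' is not silently ignored.
    """
    config = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in config:
            raise KeyError(f"unknown argument {key!r}")
        if isinstance(config[key], dict) and isinstance(value, dict):
            # Hypothetical merge rule: the given sub-keys are replaced or added,
            # the remaining sub-keys keep their default values.
            config[key].update(value)
        else:
            config[key] = value
    return config
```

For example, with_overrides(default_arg(), model='wideresnet', num_samples=24) would switch to WideResNet and double the number of samples while leaving every other default untouched, assuming default_arg() takes no arguments.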

From the main directory, run:

python3 train.py

Results

  1. Experiments on Modified MNIST

    Visualization

  2. Experiments on the CIFAR-100-C robustness benchmark

    ResNeXt as classification model

    WideResNet as classification model

    Visualization

  3. ChestXRay100 to CheXpert transfer learning test