This repository contains the official PyTorch implementation of the CVPR 2023 paper "TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation" by Devavrat Tomar, Guillaume Vray, Behzad Bozorgtabar, and Jean-Philippe Thiran.
Most recent test-time adaptation methods focus only on classification tasks, use specialized network architectures, destroy model calibration, or rely on lightweight information from the source domain. To tackle these issues, this paper proposes a novel Test-time Self-Learning method with automatic Adversarial augmentation, dubbed TeSLA, for adapting a pre-trained source model to unlabeled streaming test data. In contrast to conventional self-learning methods based on cross-entropy, we introduce a new test-time loss function through an implicitly tight connection with mutual information and online knowledge distillation. Furthermore, we propose a learnable, efficient adversarial augmentation module that further enhances online knowledge distillation by simulating high-entropy augmented images. Our method achieves state-of-the-art classification and segmentation results on several benchmarks and types of domain shift, particularly on challenging measurement shifts of medical images. Compared to competing methods, TeSLA also offers several desirable properties in terms of calibration, uncertainty metrics, and insensitivity to model architectures and source training strategies, all supported by extensive ablations.
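The abstract ties the new test-time loss to mutual information. As a rough illustration only, the sketch below writes a generic mutual-information-style objective: a cross-entropy term between student predictions and soft-pseudo labels (standing in for the conditional entropy) plus the negative entropy of the batch-average student prediction (rewarding class diversity). This is an assumption made for this README and is not necessarily the exact TeSLA loss; the function names are hypothetical and plain Python lists stand in for PyTorch tensors.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mi_style_loss(student_probs: Sequence[Sequence[float]],
                  pseudo_labels: Sequence[Sequence[float]],
                  eps: float = 1e-8) -> float:
    """Hypothetical MI-style test-time loss for one batch (not the paper's exact loss).

    Term 1: mean cross-entropy between soft-pseudo labels (targets) and student
            predictions, pushing confident, teacher-consistent outputs.
    Term 2: negative entropy of the batch-marginal student prediction;
            minimizing it spreads predictions over classes and avoids collapse.
    """
    n = len(student_probs)
    num_classes = len(student_probs[0])
    ce = -sum(
        y_c * math.log(p_c + eps)
        for p, y in zip(student_probs, pseudo_labels)
        for p_c, y_c in zip(p, y)
    ) / n
    marginal = [sum(p[c] for p in student_probs) / n for c in range(num_classes)]
    neg_marginal_entropy = sum(m_c * math.log(m_c + eps) for m_c in marginal)
    return ce + neg_marginal_entropy


if __name__ == "__main__":
    diverse = [softmax([2.0, 0.0]), softmax([0.0, 2.0])]
    collapsed = [softmax([2.0, 0.0]), softmax([2.0, 0.0])]
    print(mi_style_loss(diverse, [[1.0, 0.0], [0.0, 1.0]]))
    print(mi_style_loss(collapsed, [[1.0, 0.0], [1.0, 0.0]]))
```

With equally confident predictions, a batch that collapses onto one class scores a higher (worse) loss than a diverse batch, which is the behaviour the marginal-entropy term is meant to encourage. In PyTorch, both terms would be computed on the softmax of the student logits so that gradients reach the student.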
(a) The student model is adapted on the test images by minimizing the proposed test-time objective . The high-quality soft-pseudo labels required by are obtained from the exponentially weighted averaged teacher model and refined using the proposed Soft-Pseudo Label Refinement (PLR) on the corresponding test images. The soft-pseudo labels are further utilized for teacher-student knowledge distillation via on the adversarially augmented views of the test images. (b) The adversarial augmentations are obtained by applying learned sub-policies sampled i.i.d. from using the probability distribution with their corresponding magnitudes selected from . The parameters and of the augmentation module are updated by the unbiased gradient estimator of the loss computed on the augmented test images.
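The caption mentions two mechanisms that are easy to illustrate in isolation: an exponentially weighted averaged (EMA) teacher and sub-policies sampled i.i.d. from a learned probability distribution, each paired with a magnitude. The stdlib sketch below assumes parameters stored as flat lists of floats, a softmax over learnable logits as the sampling distribution, and placeholder sub-policy names; `ema_update`, `policy_probs`, and `sample_sub_policies` are hypothetical helpers, not functions from the TeSLA codebase, and the gradient estimator that learns the distribution and magnitudes is not shown.

```python
import math
import random
from typing import List, Sequence, Tuple


def ema_update(teacher: Sequence[float], student: Sequence[float],
               momentum: float = 0.999) -> List[float]:
    """EMA teacher update: teacher <- momentum * teacher + (1 - momentum) * student."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]


def policy_probs(logits: Sequence[float]) -> List[float]:
    """Softmax over learnable sub-policy logits (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_sub_policies(names: Sequence[str], logits: Sequence[float],
                        magnitudes: Sequence[float], k: int,
                        rng: random.Random) -> List[Tuple[str, float]]:
    """Draw k sub-policies i.i.d. (with replacement), each paired with its magnitude."""
    probs = policy_probs(logits)
    idx = rng.choices(range(len(names)), weights=probs, k=k)
    return [(names[i], magnitudes[i]) for i in idx]


if __name__ == "__main__":
    print(ema_update([0.0, 1.0], [1.0, 1.0], momentum=0.9))  # approximately [0.1, 1.0]
    rng = random.Random(0)
    # Hypothetical sub-policy names, logits, and magnitudes for illustration only.
    print(sample_sub_policies(["rotate", "contrast", "posterize"],
                              [0.0, 1.0, -1.0], [0.3, 0.5, 0.8], k=4, rng=rng))
```

A high momentum makes the teacher change slowly, which is what lets it provide stable soft-pseudo labels while the student adapts to each incoming batch.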
First, install Anaconda (Python >= 3.8) using this link. Then create the TeSLA conda environment by running the following commands:
```bash
conda create --name TeSLA python=3.8
conda activate TeSLA
conda install pip
pip install -r requirements.txt
```
In later sessions, activate the TeSLA environment with:

```bash
conda activate TeSLA
```
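Once the environment is active, a quick sanity check can confirm that the interpreter meets the Python >= 3.8 requirement and that PyTorch is importable. The helper below is hypothetical and not part of the repository; it only assumes that `torch` is among the packages installed from `requirements.txt`, and other package names can be passed in explicitly.

```python
import importlib.util
import sys
from typing import Dict, Sequence, Tuple


def check_environment(packages: Sequence[str] = ("torch",),
                      min_python: Tuple[int, int] = (3, 8)) -> Dict[str, bool]:
    """Report whether Python is recent enough and whether each package can be found."""
    report = {"python>=%d.%d" % min_python: sys.version_info[:2] >= min_python}
    for name in packages:
        # find_spec returns None when a top-level module cannot be located.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for item, ok in check_environment().items():
        print(f"{item}: {'OK' if ok else 'MISSING'}")
```

Using `find_spec` checks availability without importing PyTorch, so the check stays fast even when CUDA initialization would be slow.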
Download the test datasets and extract each one to the listed path:

| Dataset Name | Download Link | Extract to Relative Path |
|---|---|---|
| CIFAR-10C | click here | ../Datasets/cifar_dataset/CIFAR-10-C/ |
| CIFAR-100C | click here | ../Datasets/cifar_dataset/CIFAR-100-C/ |
| ImageNet-C | click here | ../Datasets/imagenet_dataset/ |
| VisDA-C | click here | ../Datasets/visda_dataset |
| Kather | click here | ../Datasets/Kather/kather2016 |
| VisDA-S | click here | ../Datasets/visda_segmentation_dataset |
| (MRI) Spinal Cord | click here | ../Datasets/MRI/SpinalCord |
| (MRI) Prostate | click here | ../Datasets/MRI/Prostate |
Download the pre-trained source classifiers and extract each one to the listed path:

| Source Dataset | Download Link | Extract to Relative Path |
|---|---|---|
| CIFAR-10 | click here | ../Source_classifiers/cifar10 |
| CIFAR-100 | click here | ../Source_classifiers/cifar100 |
| ImageNet | PyTorch default | N/A |
| VisDA-C | click here | ../Source_classifier/VisDA |
| Kather | click here | ../Source_classifier/Kather |
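Download the pre-trained source segmentation models and extract each one to the listed path:

| Source Dataset | Download Link | Extract to Relative Path |
|---|---|---|
| VisDA-S | click here | ../Source_Segmentation/VisDA/ |
| MRI (Spinal Cord and Prostate) | click here | ../Source_Segmentation/MRI/ |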
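Because the scripts expect data and checkpoints at fixed relative locations, it can help to check the layout before launching a run. The sketch below copies the relative paths from the tables above and assumes they are resolved against the repository root (the directory the scripts are run from) and that each one is a directory; `missing_paths` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path
from typing import Iterable, List

# Relative paths copied from the dataset and source-model tables above.
EXPECTED_PATHS = [
    "../Datasets/cifar_dataset/CIFAR-10-C/",
    "../Datasets/cifar_dataset/CIFAR-100-C/",
    "../Datasets/imagenet_dataset/",
    "../Datasets/visda_dataset",
    "../Datasets/Kather/kather2016",
    "../Datasets/visda_segmentation_dataset",
    "../Datasets/MRI/SpinalCord",
    "../Datasets/MRI/Prostate",
    "../Source_classifiers/cifar10",
    "../Source_classifiers/cifar100",
    "../Source_classifier/VisDA",
    "../Source_classifier/Kather",
    "../Source_Segmentation/VisDA/",
    "../Source_Segmentation/MRI/",
]


def missing_paths(repo_root: str, paths: Iterable[str] = EXPECTED_PATHS) -> List[str]:
    """Return the expected paths that do not exist as directories under repo_root."""
    root = Path(repo_root)
    return [p for p in paths if not (root / p).is_dir()]


if __name__ == "__main__":
    for p in missing_paths("."):
        print("missing:", p)
```

Only the datasets and models for the experiments you plan to run need to be present, so a partial list of missing paths is expected.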
Classification tasks on the CIFAR, ImageNet, VisDA-C, and Kather datasets, with online and offline adaptation:
(1) Common Image Corruptions: CIFAR-10C
```bash
bash scripts_classification/online/cifar10.sh
bash scripts_classification/offline/cifar10.sh
```
(2) Common Image Corruptions: CIFAR-100C
```bash
bash scripts_classification/online/cifar100.sh
bash scripts_classification/offline/cifar100.sh
```
(3) Common Image Corruptions: ImageNet-C
```bash
bash scripts_classification/online/imagenet.sh
bash scripts_classification/offline/imagenet.sh
```
(4) Synthetic to Real Adaptation: VisDA-C
```bash
bash scripts_classification/online/visdac.sh
bash scripts_classification/offline/visdac.sh
```
(5) Medical Measurement Shifts: Kather
```bash
bash scripts_classification/online/kather.sh
bash scripts_classification/offline/kather.sh
```
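Segmentation tasks, with online and offline adaptation:
(1) Synthetic to Real Adaptation: GTA5 to Cityscapes (VisDA-S)

```bash
bash scripts_segmentation/online/cityscapes.sh
bash scripts_segmentation/offline/cityscapes.sh
```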
(2) Medical Measurement Shifts: MRI (Spinal Cord and Prostate)

```bash
bash scripts_segmentation/online/spinalcord.sh
bash scripts_segmentation/offline/prostate.sh
```
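To queue several of the experiments above in one go, a small launcher can call the listed shell scripts in sequence. The sketch below is a hypothetical convenience wrapper, not part of the repository: it uses exactly the script paths from this README, defaults to a dry run that only prints commands, and stops at the first script that exits with a non-zero status when actually executing.

```python
import subprocess
from typing import Dict, List

# Script paths exactly as listed in this README, grouped by task.
SCRIPTS: Dict[str, List[str]] = {
    "classification": [
        "scripts_classification/online/cifar10.sh",
        "scripts_classification/offline/cifar10.sh",
        "scripts_classification/online/cifar100.sh",
        "scripts_classification/offline/cifar100.sh",
        "scripts_classification/online/imagenet.sh",
        "scripts_classification/offline/imagenet.sh",
        "scripts_classification/online/visdac.sh",
        "scripts_classification/offline/visdac.sh",
        "scripts_classification/online/kather.sh",
        "scripts_classification/offline/kather.sh",
    ],
    "segmentation": [
        "scripts_segmentation/online/cityscapes.sh",
        "scripts_segmentation/offline/cityscapes.sh",
        "scripts_segmentation/online/spinalcord.sh",
        "scripts_segmentation/offline/prostate.sh",
    ],
}


def build_commands(task: str, mode: str = "all") -> List[List[str]]:
    """Return `bash <script>` commands for a task, optionally only 'online' or 'offline'."""
    if mode not in ("all", "online", "offline"):
        raise ValueError(f"unknown mode: {mode}")
    return [["bash", s] for s in SCRIPTS[task] if mode == "all" or f"/{mode}/" in s]


def run_all(task: str, mode: str = "all", dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False (stop on first failure)."""
    for cmd in build_commands(task, mode):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all("classification", mode="online", dry_run=True)
```

Run it from the repository root so that the relative script paths resolve, and set `dry_run=False` only after the datasets and source models are in place.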
If you find our work useful, please consider citing:
```bibtex
@inproceedings{tomar2023TeSLA,
  title={TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation},
  author={Tomar, Devavrat and Vray, Guillaume and Bozorgtabar, Behzad and Thiran, Jean-Philippe},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}
```