This repository is the official PyTorch implementation of SDA.
Everything to the Synthetic: Diffusion-driven Test-time Adaptation via Synthetic-Domain Alignment
Jiayi Guo, Junhao Zhao, Chunjiang Ge, Chaoqun Du, Zanlin Ni, Shiji Song, Humphrey Shi, Gao Huang
Synthetic-Domain Alignment (SDA) is a novel test-time adaptation framework that simultaneously aligns the domains of the source model and target data with the same synthetic domain of a diffusion model.
SDA is a two-stage test-time adaptation (TTA) framework that aligns both the source model and the target data with the synthetic domain. In Stage 1, the source-domain model is adapted into a synthetic-domain model by fine-tuning on synthetic data. This synthetic data is first generated by a conditional diffusion model from domain-agnostic class labels, then re-synthesized through an unconditional diffusion process so that its domain matches that of the projected target data in Stage 2. In Stage 2, target data is projected into the synthetic domain using unconditional diffusion, and the synthetic-domain model makes predictions on the projected data.
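The data flow of the two stages can be sketched schematically as below. This is only an illustration of the description above, not the code in this repository: every callable (`generate`, `resynthesize`, `finetune`, `predict`) is a hypothetical stand-in for the corresponding component (conditional diffusion, unconditional diffusion, fine-tuning, classification).

```python
from typing import Any, Callable, Iterable, List, Tuple


def stage1_build_synthetic_model(
    source_model: Any,
    class_labels: Iterable[int],
    generate: Callable[[int], Any],      # stand-in: conditional diffusion, label -> image
    resynthesize: Callable[[Any], Any],  # stand-in: unconditional diffusion projection
    finetune: Callable[[Any, List[Tuple[Any, int]]], Any],
) -> Any:
    """Stage 1: adapt the source model to the synthetic domain."""
    dataset = []
    for label in class_labels:
        image = generate(label)      # label-conditioned synthetic sample
        image = resynthesize(image)  # same projection that Stage 2 applies to target data
        dataset.append((image, label))
    return finetune(source_model, dataset)


def stage2_predict(
    synthetic_model: Any,
    target_images: Iterable[Any],
    resynthesize: Callable[[Any], Any],
    predict: Callable[[Any, Any], Any],
) -> List[Any]:
    """Stage 2: project target data into the synthetic domain, then classify."""
    return [predict(synthetic_model, resynthesize(x)) for x in target_images]
```

The point of the sketch is that `resynthesize` appears in both stages, so the fine-tuning data and the projected test data pass through the same projection.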
- [2023.06.07] Code, data and models released!
git clone https://github.com/SHI-Labs/Diffusion-Driven-Test-Time-Adaptation-via-Synthetic-Domain-Alignment.git
cd Diffusion-Driven-Test-Time-Adaptation-via-Synthetic-Domain-Alignment
conda env create -f environment.yml
conda activate SDA
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
mim install mmcv-full
mim install mmcls
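After installation, a quick sanity check of the environment may help. The sketch below uses only the Python standard library to report which of the packages named above are installed and at what version. The package list is taken from the install commands; the helper name `installed_versions` is our own.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    names = ["torch", "torchvision", "torchaudio", "mmcv-full", "mmcls"]
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```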
Download ImageNet-C and generate ImageNet-W following the official repos.
For a quick evaluation, also download our re-synthesized ImageNet-C-Syn and ImageNet-W-Syn via Google Drive or Tsinghua Cloud. Place all datasets into data/.
data
|——ImageNet-C
|   |——gaussian_noise
|   |   |——5
|   |   |   |——n01440764
|   |   |   |   |——*.JPEG
|——ImageNet-C-Syn
|   |——gaussian_noise
|   |   |——5
|   |   |   |——n01440764
|   |   |   |   |——*.JPEG
|——ImageNet-W
|   |——val
|   |   |——n01440764
|   |   |   |——*.JPEG
|——ImageNet-W-Syn
|   |——val
|   |   |——n01440764
|   |   |   |——*.JPEG
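Because evaluation reads each original dataset together with its re-synthesized counterpart, it can be useful to confirm that the two directory trees line up. The following is a minimal sketch, assuming that a re-synthesized image sits at the same relative path and has the same file stem as its original (the file extension is ignored). The helper name `unpaired_images` is hypothetical and not part of this repository.

```python
from pathlib import Path


def unpaired_images(original_root, synthetic_root):
    """List files under original_root with no same-path, same-stem file under synthetic_root."""
    original_root, synthetic_root = Path(original_root), Path(synthetic_root)

    def stems(root):
        # Relative path without extension, e.g. gaussian_noise/5/n01440764/<name>
        return {
            p.relative_to(root).with_suffix("")
            for p in root.rglob("*")
            if p.is_file()
        }

    missing = stems(original_root) - stems(synthetic_root)
    return sorted(p.as_posix() for p in missing)


if __name__ == "__main__":
    for rel in unpaired_images("data/ImageNet-C", "data/ImageNet-C-Syn"):
        print("no synthesized counterpart:", rel)
```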
You can also re-synthesize the test datasets yourself following the official repo of DDA or using our align.sh.
Download pretrained checkpoints (diffusion models and classifiers) via the following command:
bash scripts/download_ckpt.sh
For a quick evaluation, also download our fine-tuned checkpoints via Google Drive or Tsinghua Cloud. Place the checkpoints into finetuned_ckpt/.
We provide example commands to evaluate finetuned models on both ImageNet-C and ImageNet-W:
bash scripts/eval.sh
You can also evaluate a custom model using the following command formats:
# ImageNet-C
CUDA_VISIBLE_DEVICES=0 python eval/test_ensemble.py <config> <finetuned ckpt> \
--originckpt <pretrained ckpt> --metrics accuracy --datatype C --ensemble sda --corruption <corruption type> --data_prefix1 data/ImageNet-C --data_prefix2 data/ImageNet-C-Syn
# ImageNet-W
CUDA_VISIBLE_DEVICES=0 python eval/test_ensemble.py <config> <finetuned ckpt> \
--originckpt <pretrained ckpt> --metrics accuracy --datatype W --ensemble sda --data_prefix1 data/ImageNet-W --data_prefix2 data/ImageNet-W-Syn
You may need to set up a new config for your customized model according to our evaluation configs.
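To evaluate one model across all ImageNet-C corruptions, the command above can be generated once per corruption type. The sketch below only builds and prints the commands; it does not run them. It assumes that `--corruption` accepts the standard ImageNet-C directory names (as `gaussian_noise` appears in the data layout); the config and checkpoint arguments remain placeholders to fill in.

```python
import shlex

# The 15 standard ImageNet-C corruption types (directory names).
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]


def build_eval_command(config, finetuned_ckpt, origin_ckpt, corruption):
    """Assemble the ImageNet-C evaluation command as an argument list."""
    return [
        "python", "eval/test_ensemble.py", config, finetuned_ckpt,
        "--originckpt", origin_ckpt,
        "--metrics", "accuracy",
        "--datatype", "C",
        "--ensemble", "sda",
        "--corruption", corruption,
        "--data_prefix1", "data/ImageNet-C",
        "--data_prefix2", "data/ImageNet-C-Syn",
    ]


if __name__ == "__main__":
    for corruption in CORRUPTIONS:
        cmd = build_eval_command(
            "<config>", "<finetuned ckpt>", "<pretrained ckpt>", corruption
        )
        print(shlex.join(cmd))  # quoted so each line can be pasted into a shell
```

Returning an argument list rather than a single string keeps the command safe to pass to `subprocess.run` if you prefer to launch the runs from Python.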
Run the following command to generate a synthetic dataset via DiT:
bash scripts/gen.sh
The synthetic dataset contains the 1000 ImageNet classes, with 50 images per class:
data
|——DiT-XL-2-DiT-XL-2-256x256-size-256-vae-ema-cfg-1.0-seed-0
| |——0000
| | |——*.png
| |——0001
| |——....
| |——0999
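A generation run that was interrupted can leave some classes short of images. The sketch below counts the `*.png` files in each class folder and reports any folder that does not hold the expected number. The defaults (1000 classes, 50 images per class) come from the description above; the function names are our own and not part of this repository.

```python
from pathlib import Path


def class_image_counts(root, pattern="*.png"):
    """Map each class folder name under root to its number of matching images."""
    root = Path(root)
    return {
        d.name: sum(1 for _ in d.glob(pattern))
        for d in sorted(root.iterdir())
        if d.is_dir()
    }


def find_incomplete_classes(root, expected_classes=1000, images_per_class=50):
    """Return (class_count_ok, {class_name: count}) for folders with a wrong count."""
    counts = class_image_counts(root)
    bad = {name: n for name, n in counts.items() if n != images_per_class}
    return len(counts) == expected_classes, bad


if __name__ == "__main__":
    root = "data/DiT-XL-2-DiT-XL-2-256x256-size-256-vae-ema-cfg-1.0-seed-0"
    ok, bad = find_incomplete_classes(root)
    print("class count ok:", ok)
    for name, n in bad.items():
        print(f"class {name}: {n} images")
```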
Run the following command to project the synthetic dataset to the domain of ADM:
bash scripts/align.sh
We also provide example commands to project target data (ImageNet-C/W) to the domain of ADM, following the same procedure as DDA. Check align.sh for more details.
Important: Our fine-tuning code is built on MMPreTrain, which conflicts with the mmcv version used by the DDA-based data alignment (Step 2). You therefore need to set up a separate environment before fine-tuning:
conda env create -f environment_ft.yml
conda activate SDA-FT
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
mim install mmcv
We provide fine-tuning configs for five models used in our paper. Run the following command to start synthetic data fine-tuning:
bash scripts/finetune.sh
In our implementation, we use a 30-epoch fine-tuning schedule. Empirically, we find that 15 epochs of training are sufficient for evaluation.
If you want to fine-tune other models, please refer to MMPreTrain to set up a new config.
- ImageNet-C
- ImageNet-W
- Visualization: Grad-CAM results with prediction classes and confidence scores displayed above the images.
If you find our work helpful, please star 🌟 this repo and cite 📑 our paper. Thanks for your support!
@article{guo2024sda,
title={Everything to the Synthetic: Diffusion-driven Test-time Adaptation via Synthetic-domain Alignment},
author={Jiayi Guo and Junhao Zhao and Chunjiang Ge and Chaoqun Du and Zanlin Ni and Shiji Song and Humphrey Shi and Gao Huang},
journal={arXiv},
year={2024}
}
We thank MMPreTrain (model fine-tuning), DiT (data synthesis) and DDA (data alignment).
guo-jy20 at mails dot tsinghua dot edu dot cn