💽 DGInStyle: Domain-Generalizable Semantic Segmentation with Image Diffusion Models and Stylized Semantic Control
ECCV 2024 [Project Page] | [ArXiv] | [Datasets]
By Yuru Jia, Lukas Hoyer, Shengyu Huang, Tianfu Wang, Luc Van Gool, Konrad Schindler, Anton Obukhov
We propose DGInStyle, an efficient data generation pipeline with a pretrained text-to-image latent diffusion model (LDM) at its core. The semantic segmentation models trained on our generated dataset offer improved domain generalization, drawing on the prior knowledge embedded in the LDM.
This repository hosts the training and inference implementation of the DGInStyle data generation pipeline. To assess the effectiveness of the generated data, we provide DGInStyle-SegModel, which is designed for downstream semantic segmentation tasks.
This project was developed with Python 3.8.5 and CUDA 12.1.1.
Clone the repository:
git clone https://github.com/yurujaja/DGInStyle.git
cd DGInStyle
export DGINSTYLE_PATH=/path/to/venv/dginstyle # change this
python3 -m venv ${DGINSTYLE_PATH}
source ${DGINSTYLE_PATH}/bin/activate
pip install -r requirements.txt
We provide a demo inference script that runs on a few example semantic conditions. The weights are available from our Hugging Face repository and include the Stable Diffusion weights fine-tuned on source-domain data (GTA), the ControlNet weights, and the SegFormer weights trained using DGInStyle. For the Stable Diffusion weights pretrained on prior-domain data (LAION-5B), we adopt RunwayML's stable-diffusion-v1-5.
We provide example_data so that you can preview the data created under specific semantic conditions. Please run the command given below. The output will be saved by default in the example_data/output folder.
python demo.py
The weights are automatically downloaded to the Hugging Face cache by default. The location of this cache can be changed by overriding the HF_HOME environment variable:
export HF_HOME=new/path # change this
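To confirm where the weights will end up, the cache root can be resolved from the environment. The following is a minimal sketch, not part of this repository: `resolve_hf_home` is a hypothetical helper, and the fallback order (`HF_HOME`, then `XDG_CACHE_HOME/huggingface`, then `~/.cache/huggingface`) is assumed from the Hugging Face Hub documentation.

```python
import os
from pathlib import Path


def resolve_hf_home(env=None):
    """Return the directory Hugging Face libraries would use as cache root.

    Hypothetical helper: it only mirrors the documented lookup order and
    does not call into any Hugging Face library.
    """
    env = os.environ if env is None else env
    # An explicitly set HF_HOME takes precedence.
    if env.get("HF_HOME"):
        return Path(env["HF_HOME"]).expanduser()
    # Otherwise fall back to $XDG_CACHE_HOME/huggingface or ~/.cache/huggingface.
    cache_root = env.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
    return (Path(cache_root) / "huggingface").expanduser()


if __name__ == "__main__":
    print(f"Hugging Face cache root: {resolve_hf_home()}")
```

Running this before and after exporting HF_HOME shows whether the override is visible to the Python process that will run demo.py.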
Following the common practice in domain generalization, we use GTA as the synthetic source dataset. Please download all image and label packages from here and extract them to data/gta. Then run the following script to generate the dataset training file:
python -m controlnet.tools.convert_dataset_file
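Before running the conversion, it can help to check that the extraction is complete. The sketch below is not part of this repository. It assumes that the packages extract into data/gta/images and data/gta/labels, and that each image has a PNG label with the same file name; the `labels` folder name, the PNG extension, and the `find_unpaired` helper are assumptions for illustration.

```python
from pathlib import Path


def find_unpaired(root="data/gta", image_dir="images", label_dir="labels"):
    """Return sorted image file names that have no same-named label file.

    Assumed layout (not confirmed by the repository):
        <root>/<image_dir>/00001.png
        <root>/<label_dir>/00001.png
    """
    root = Path(root)
    images = {p.name for p in (root / image_dir).glob("*.png")}
    labels = {p.name for p in (root / label_dir).glob("*.png")}
    # Images without a matching label usually point to a missing package.
    return sorted(images - labels)
```

Calling `find_unpaired()` from the repository root returns an empty list when every image has a label; a non-empty list names the images to re-extract.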
To enable the rare class sampling (RCS) component, please refer to this repository to generate the RCS-related files.
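For intuition on what RCS does with those files, here is a minimal sketch of the sampling probabilities. It assumes the formulation introduced in DAFormer, where a class with pixel frequency f_c is sampled with probability proportional to exp((1 - f_c) / T). The function name, the input format, and the temperature value are illustrative assumptions; the referenced repository may compute and store these statistics differently.

```python
import math


def rcs_probabilities(class_pixel_counts, temperature=0.01):
    """Map per-class pixel counts to rare-class sampling probabilities.

    Sketch of a softmax over (1 - f_c) / T, where f_c is the fraction of
    all pixels that belong to class c. Lower T favors rare classes more.
    """
    total = sum(class_pixel_counts.values())
    if total <= 0:
        raise ValueError("class_pixel_counts must contain at least one pixel")
    freqs = {c: n / total for c, n in class_pixel_counts.items()}
    logits = {c: (1.0 - f) / temperature for c, f in freqs.items()}
    # Subtract the maximum logit for numerical stability before exponentiating.
    max_logit = max(logits.values())
    exps = {c: math.exp(v - max_logit) for c, v in logits.items()}
    norm = sum(exps.values())
    return {c: e / norm for c, e in exps.items()}


if __name__ == "__main__":
    # Toy counts, not real GTA statistics.
    print(rcs_probabilities({"road": 900, "car": 90, "train": 10}, temperature=0.5))
```

With the toy counts above, the rarest class receives the highest probability and the most frequent class the lowest, which is the behavior RCS relies on to show rare classes more often during training.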
As a first step of our Style Swap technique, we fine-tune the original latent diffusion model's U-Net on the source-domain (GTA) images with DreamBooth:
# Before running: set --output_dir to your own path, choose wandb or tensorboard
# for --report_to, and adjust --validation_prompt as needed.
python dreambooth/train_dreambooth.py \
--pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5' \
--instance_data_dir='data/gta/images/' \
--output_dir='path/to/your/dreambooth/output' \
--instance_prompt='' \
--resolution=512 \
--train_batch_size=1 \
--sample_batch_size=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--learning_rate=2e-06 \
--lr_scheduler=constant \
--lr_warmup_steps=1000 \
--max_train_steps=30000 \
--mixed_precision=fp16 \
--report_to=wandb \
--validation_prompt='roads building sky person vegetation car' \
--validation_steps=200 \
--checkpointing_steps=1000
The resulting fine-tuned U-Net is then used as the pretrained model for ControlNet training:
export U_S_MODEL_DIR='path/to/your/dreambooth/output' # change this
export OUTPUT_DIR='path/to/your/controlnet/output' # change this
export DATASET_FILE='data/gta/dataset_file.json'
accelerate launch controlnet/train.py \
--pretrained_model_name_or_path=$U_S_MODEL_DIR \
--inference_basemodel_path='runwayml/stable-diffusion-v1-5' \
--output_dir=$OUTPUT_DIR \
--dataset_file=$DATASET_FILE \
--learning_rate=1e-05 \
--train_batch_size=2 \
--gradient_accumulation_steps=4 \
--checkpointing_steps=2000 \
--validation_steps=1000 \
--report_to='wandb' \
--rcs_enabled \
--rcs_data_root='data/gta/' \
--tracker_project_name='dginstyle_training' \
--validate_file='controlnet/tools/example_validation.jsonl'
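When adapting the two training commands to different hardware, it is useful to keep the effective batch size in mind. The sketch below works through the arithmetic for the flag values shown above. It assumes a single process (one GPU); `num_processes` and the helper name are illustrative and not flags of the training scripts.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of samples contributing to each optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the DreamBooth and ControlNet commands above,
# assuming single-process training.
dreambooth = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=2)
controlnet = effective_batch_size(train_batch_size=2, gradient_accumulation_steps=4)

if __name__ == "__main__":
    print(f"DreamBooth effective batch size: {dreambooth}")
    print(f"ControlNet effective batch size: {controlnet}")
```

If GPU memory forces a smaller --train_batch_size, raising --gradient_accumulation_steps by the same factor keeps this product unchanged.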
This repository is based on the following open-source projects. We thank their authors for making the source code publicly available.
This project is released under the Apache License 2.0, while some specific features in this repository are covered by other licenses. Please check LICENSE carefully if you intend to use our code for commercial purposes.
@inproceedings{jia2024dginstyle,
title = {DGInStyle: Domain-Generalizable Semantic Segmentation with Image Diffusion Models and Stylized Semantic Control},
author = {Yuru Jia and Lukas Hoyer and Shengyu Huang and Tianfu Wang and Luc Van Gool and Konrad Schindler and Anton Obukhov},
booktitle = {European Conference on Computer Vision},
year = {2024},
}