- 1. Overview
- 2. To-Do List
- 3. Code Structure
- 4. Implementation Details
- 5. Prerequisites
- 6. Training
- 7. Sampling
- 8. Results
- 9. Star History
This is a re-implementation of ControlNet trained with shape masks. If you have any suggestions about this repo, please feel free to open a new issue or propose a PR.
- Update install.sh
- Update the pre-annotated masks of the COCO 2014 training set
- Regular maintenance
shape-guided-controlnet
├── LICENSE
├── README.md
├── annotators <----- Code of annotators for shape masks
│ └── u2net_saliency_detection
├── dataset_loaders <----- Code of dataset loaders
├── examples <----- Example conditions for validation use
│ └── conditions
├── inference.py <----- Script to run inference with a trained ControlNet model
├── runners <----- Source code of training and inference runners
│ ├── controlnet_inference_runner.py
│ └── controlnet_train_runner.py
├── train.py <----- Script to train the ControlNet model
└── utils <----- Code of toolkit functions
The ControlNet model is trained on the COCO 2014 dataset for 100,000 iterations with a batch size of 4.
Each data sample consists of an image, a descriptive caption, and a shape mask.
The image captions come directly from the official COCO annotations (i.e., captions_train2014.json).
To obtain the shape masks, I select an off-the-shelf saliency detection model, u2net, to automatically annotate each image.
Model weights of the annotator and the trained ControlNet are released in the Hugging Face repo.
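The automatic annotation itself is handled by the code under annotators/u2net_saliency_detection. As a rough illustration of the idea only, turning a raw saliency prediction into a binary shape mask could look like the sketch below; the normalization and the 0.5 threshold are assumptions of this sketch, not necessarily what the script does.

```python
# Illustrative sketch only: convert a raw saliency prediction into a binary
# shape mask. The threshold and normalization are assumptions of this sketch;
# the actual logic lives in annotators/u2net_saliency_detection.
import numpy as np
from PIL import Image

def saliency_to_mask(saliency: np.ndarray, threshold: float = 0.5) -> Image.Image:
    # Normalize the saliency map to [0, 1].
    sal = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    # Binarize: salient foreground becomes white (255), background black (0).
    return Image.fromarray(((sal > threshold) * 255).astype(np.uint8), mode="L")

# Example with a random array standing in for a real u2net prediction.
saliency_to_mask(np.random.rand(320, 320)).save("example_mask.png")
```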
1. To install all the dependencies, you can use the one-click installation script install.sh by simply running:
bash install.sh
2. Follow the steps below to prepare the dataset. First, annotate the shape masks for the training images by running:
python annotators/u2net_saliency_detection/generate_masks.py --indir COCO_TRAIN2014_IMAGES_PATH --outdir COCO_TRAIN2014_MASKS_PATH --model_dir U2NET_CHECKPOINT_PATH
You can refer to this example command line:
python annotators/u2net_saliency_detection/generate_masks.py --indir ./data/COCO2014/train2014 --outdir ./data/COCO2014/train2014_masks --model_dir ./checkpoints/u2net.pth
Once the training data is ready, it should follow the structure below:
COCO2014
├── train2014 <----- Training images
├── train2014_masks <----- Annotated shape masks
├── val2014
├── test2014
└── annotations <----- Annotation files
├── captions_train2014.json <----- We will use the annotated captions in this file
├── captions_val2014.json
├── image_info_test-dev2015.json
├── image_info_test2015.json
├── instances_train2014.json
├── instances_val2014.json
├── person_keypoints_train2014.json
└── person_keypoints_val2014.json
Alternatively, you can simply download the pre-annotated dataset from this HF dataset repo (a minimal sketch of how the images, captions, and masks pair up is shown after this list).
3. To prepare the pre-trained model weights of Stable Diffusion, you can download them from our Hugging Face repo.
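As referenced above, here is a minimal sketch of how one (image, caption, mask) triplet could be assembled from the directory structure in step 2. The COCO captions JSON layout is standard; the mask filename convention (same base name as the image, saved as PNG) is an assumption of this sketch, and the repo's actual loading logic lives in dataset_loaders/.

```python
# Minimal sketch: assemble one (image, caption, mask) triplet from the
# directory structure above. The PNG mask naming convention is an assumption;
# the repo's actual loader lives in dataset_loaders/.
import json
from pathlib import Path
from PIL import Image

root = Path("./data/COCO2014")
coco = json.loads((root / "annotations/captions_train2014.json").read_text())

id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}
id_to_caption = {}
for ann in coco["annotations"]:
    # Keep the first caption per image; COCO provides several captions each.
    id_to_caption.setdefault(ann["image_id"], ann["caption"])

image_id = next(iter(id_to_caption))
file_name = id_to_file[image_id]
image = Image.open(root / "train2014" / file_name).convert("RGB")
mask = Image.open(root / "train2014_masks" / (Path(file_name).stem + ".png")).convert("L")
print(file_name, "->", id_to_caption[image_id])
```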
Once the data and pre-trained model weights are ready, you can train the ControlNet model with the following command:
python train.py --pretrained_model_name_or_path SD_V1.5_CHECKPOINTS_PATH --train_batch_size TRAIN_BATCH_SIZE --output_dir OUTPUT_DIR --image_path IMAGES_PATH --caption_path ANNOTATION_FILE_PATH --condition_path CONDITION_PATH --validation_steps VALIDATION_STEPS --validation_image VALIDATION_IMAGE --validation_prompt VALIDATION_PROMPT --checkpointing_steps CHECKPOINTING_STEPS
You can refer to the following example command line:
python train.py --pretrained_model_name_or_path ./checkpoints/stable-diffusion-v1.5 --train_batch_size 4 --output_dir ./outputs/shape-guided-controlnet --image_path ./data/COCO2014/train2014 --caption_path ./data/COCO2014/annotations/captions_train2014.json --condition_path ./data/COCO2014/train2014_masks --validation_steps 1000 --validation_image "examples/bag" "examples/sport_car.png" "examples/truck.png" --validation_prompt "a red bag" "a sport car" "a blue truck" --checkpointing_steps 1000
Note that three example conditions are included in ./examples for validating intermediate checkpoints during training.
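For context, the training logic itself lives in runners/controlnet_train_runner.py. The sketch below is a rough, simplified illustration of what a single ControlNet training step looks like in a diffusers-style setup; it is not a copy of this repo's code, and the random tensors stand in for a real batch of (image, caption, shape mask).

```python
# Rough illustration of one ControlNet training step in a diffusers-style
# setup; not a copy of this repo's runner. Random tensors stand in for a
# real batch of (image, caption, shape mask).
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, ControlNetModel, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

sd_path = "./checkpoints/stable-diffusion-v1.5"
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(sd_path, subfolder="scheduler")

controlnet = ControlNetModel.from_unet(unet)  # only this branch is optimized
for frozen in (vae, unet, text_encoder):
    frozen.requires_grad_(False)

pixel_values = torch.randn(1, 3, 512, 512)  # placeholder training image
shape_mask = torch.rand(1, 3, 512, 512)     # placeholder shape-mask condition
input_ids = tokenizer("a red bag", padding="max_length",
                      max_length=tokenizer.model_max_length,
                      return_tensors="pt").input_ids

with torch.no_grad():
    latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
    encoder_hidden_states = text_encoder(input_ids)[0]

noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# The ControlNet consumes the shape mask and injects residuals into the UNet.
down_res, mid_res = controlnet(
    noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=shape_mask, return_dict=False)
noise_pred = unet(
    noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res).sample

loss = F.mse_loss(noise_pred.float(), noise.float())  # standard noise-prediction loss
loss.backward()  # gradients reach only the ControlNet parameters
```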
Once the ControlNet model is trained, you can generate images with it using the following command:
python inference.py --condition_image CONDITION_IMAGE_PATH --prompt PROMPT --controlnet_model CONTROLNET_CHECKPOINTS_PATH --sd_model SD_V1.5_CHECKPOINTS_PATH --output_path OUTPUT_PATH --seed SEED
You can refer to the following example command line:
python inference.py --condition_image ./examples/conditions/sport_car.png --prompt "a sport car" --controlnet_model ./outputs/shape-guided-controlnet/checkpoint-100000/controlnet/ --sd_model ./checkpoints/stable-diffusion-v1.5/ --output_path outputs/inference/ --seed 1234
The output image will be saved to ./outputs/inference/generated_image.png by default.
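The actual sampling logic lives in runners/controlnet_inference_runner.py. As a rough equivalent, the same generation can be sketched with diffusers' StableDiffusionControlNetPipeline; the paths, prompt, and seed below mirror the example command above, but this is an assumption of how the script behaves rather than its exact code.

```python
# Rough diffusers-based equivalent of the example inference command above;
# the repo's own runner (runners/controlnet_inference_runner.py) may differ.
from pathlib import Path
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "./outputs/shape-guided-controlnet/checkpoint-100000/controlnet/")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "./checkpoints/stable-diffusion-v1.5/", controlnet=controlnet)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

condition = load_image("./examples/conditions/sport_car.png")  # the shape mask
generator = torch.Generator().manual_seed(1234)
image = pipe("a sport car", image=condition, generator=generator).images[0]

Path("outputs/inference").mkdir(parents=True, exist_ok=True)
image.save("outputs/inference/generated_image.png")
```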
Here are some example results generated by the trained model:
- "A red bag"
- "A sport car"
- "A blue truck"