Shu Zhang*1, Xinyi Yang*1, Yihao Feng*1, Can Qin3, Chia-Chih Chen1, Ning Yu1, Zeyuan Chen1, Huan Wang1, Silvio Savarese1,2, Stefano Ermon2, Caiming Xiong1, and Ran Xu1
1Salesforce AI, 2Stanford University, 3Northeastern University
*denotes equal contribution
arXiv 2023
This is a PyTorch implementation of HIVE: Harnessing Human Feedback for Instructional Visual Editing. Most of the code follows InstructPix2Pix. This repo supports both Stable Diffusion v1.5-base and Stable Diffusion v2.1-base as the backbone.
- 07/08/23: Training code and training data are public.😊
First, set up the hive environment and download the pretrained models as shown below. This has only been verified on CUDA 11.0 and CUDA 11.3 with NVIDIA A100 GPUs.
conda env create -f environment.yaml
conda activate hive
bash scripts/download_checkpoints.sh
To fine-tune a Stable Diffusion model, you need to obtain the pre-trained Stable Diffusion weights following their instructions. For SD-V1.5, download the weights from HuggingFace SD 1.5. For SD-V2.1, download them from HuggingFace SD 2.1. You can choose which checkpoint version to use; we use v2-1_512-ema-pruned.ckpt. Download the model to checkpoints/.
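Before launching training, it can help to confirm that the checkpoint landed in checkpoints/ as described above. The following is a minimal sketch, assuming the file name v2-1_512-ema-pruned.ckpt mentioned in this README; the helper `find_checkpoint` is hypothetical and not part of this repo.

```python
from pathlib import Path


def find_checkpoint(ckpt_dir="checkpoints", name="v2-1_512-ema-pruned.ckpt"):
    """Return the checkpoint path if the file exists, otherwise None.

    Hypothetical helper for illustration; the defaults mirror the
    directory and file name given in this README.
    """
    path = Path(ckpt_dir) / name
    return path if path.is_file() else None


if __name__ == "__main__":
    ckpt = find_checkpoint()
    if ckpt is None:
        print("Checkpoint not found; download it to checkpoints/ first.")
    else:
        # Report the size so a truncated download is easier to spot.
        print(f"Found {ckpt} ({ckpt.stat().st_size / 1e9:.2f} GB)")
```

If you use a different checkpoint (for example an SD-V1.5 file), pass its name through the `name` argument.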
We suggest installing the Gcloud CLI following Gcloud download. To obtain both the training and evaluation data, run
bash scripts/download_hive_data.sh
Alternatively, download the data directly through Evaluation data and Evaluation instructions.
To train with the SD v2.1 backbone on 8 GPUs, we run
python main.py --name step1 --base configs/train_v21_base.yaml --train --gpus 0,1,2,3,4,5,6,7
Samples can be obtained by running the commands below.
For SD v2.1, if we use the conditional reward, we run
python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml
or run batch inference on our inference data:
python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21_rw_label/ --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/
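The batch scripts take a `--jsonl_file` argument. JSON Lines files hold one JSON value per line, so it can be useful to inspect data/test.jsonl before a long batch run. The sketch below is an illustration only: the helper `read_jsonl` is hypothetical, and because the schema of data/test.jsonl is not documented here, it assumes no particular field names.

```python
import json


def read_jsonl(path):
    """Parse a JSON Lines file: one JSON value per non-empty line."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                records.append(json.loads(line))
    return records


if __name__ == "__main__":
    records = read_jsonl("data/test.jsonl")
    print(f"{len(records)} records")
    if records:
        # Print the first record as-is; no field names are assumed.
        print("first record:", records[0])
```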
For SD v2.1, if we use the weighted reward, we can run
python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_v2_rw.ckpt --config configs/generate_v21_base.yaml
or run batch inference on our inference data:
python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21/ --ckpt checkpoints/hive_v2_rw.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/
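All inference commands in this README pass `--cfg-text 7.5` and `--cfg-image 1.5`. In InstructPix2Pix, which most of this code follows, these two scales combine three noise predictions through classifier-free guidance. The sketch below illustrates that combination under the assumption that HIVE inherits it unchanged; this has not been checked against the scripts in this repo. The function name `combine_guidance` is hypothetical, and plain lists of floats stand in for tensors.

```python
def combine_guidance(e_uncond, e_img, e_full, cfg_image=1.5, cfg_text=7.5):
    """Two-scale classifier-free guidance in the InstructPix2Pix style.

    e_uncond: noise prediction with neither image nor text conditioning
    e_img:    noise prediction with image conditioning only
    e_full:   noise prediction with both image and text conditioning
    Inputs are equal-length sequences of floats standing in for tensors.
    """
    return [
        u + cfg_image * (i - u) + cfg_text * (f - i)
        for u, i, f in zip(e_uncond, e_img, e_full)
    ]


if __name__ == "__main__":
    # With both scales at 1.0 the result reduces to the fully conditioned prediction.
    print(combine_guidance([0.0], [1.0], [2.0], cfg_image=1.0, cfg_text=1.0))
```

Under this formulation, raising `cfg_text` pushes the result further along the direction contributed by the edit instruction, while raising `cfg_image` pushes it further along the direction contributed by the input image.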
For SD v1.5, if we use the conditional reward, we can run
python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml
or run batch inference on our inference data:
python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15_rw_label/ \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/
For SD v1.5, if we use the weighted reward, we run
python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 --input imgs/example1.jpg \
--output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml
or run batch inference on our inference data:
python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15/ \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/
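The single-image commands above differ only in the script, checkpoint, and config they use. As a convenience, the mapping can be written down once. The sketch below copies the script, checkpoint, and config names from the commands in this README; the helper `build_edit_command` and the version and reward labels used as dictionary keys are hypothetical.

```python
import shlex

# (script, checkpoint, config) per (SD version, reward type),
# copied from the single-image commands in this README.
VARIANTS = {
    ("v2.1", "conditional"): (
        "edit_cli_rw_label.py",
        "checkpoints/hive_v2_rw_condition.ckpt",
        "configs/generate_v21_base.yaml",
    ),
    ("v2.1", "weighted"): (
        "edit_cli.py",
        "checkpoints/hive_v2_rw.ckpt",
        "configs/generate_v21_base.yaml",
    ),
    ("v1.5", "conditional"): (
        "edit_cli_rw_label.py",
        "checkpoints/hive_rw_condition.ckpt",
        "configs/generate.yaml",
    ),
    ("v1.5", "weighted"): (
        "edit_cli.py",
        "checkpoints/hive_rw.ckpt",
        "configs/generate.yaml",
    ),
}


def build_edit_command(sd_version, reward, input_path, output_path, edit):
    """Return the argument list for one single-image edit (hypothetical helper)."""
    script, ckpt, config = VARIANTS[(sd_version, reward)]
    return [
        "python", script,
        "--steps", "100", "--resolution", "512", "--seed", "100",
        "--cfg-text", "7.5", "--cfg-image", "1.5",
        "--input", input_path, "--output", output_path, "--edit", edit,
        "--ckpt", ckpt, "--config", config,
    ]


if __name__ == "__main__":
    cmd = build_edit_command(
        "v2.1", "weighted", "imgs/example1.jpg", "imgs/output.jpg", "move it to Mars"
    )
    # shlex.join quotes the edit instruction so the line can be pasted into a shell.
    print(shlex.join(cmd))
```

The returned list can also be passed directly to `subprocess.run(cmd, check=True)` from the repository root.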
@article{zhang2023hive,
title={HIVE: Harnessing Human Feedback for Instructional Visual Editing},
author={Zhang, Shu and Yang, Xinyi and Feng, Yihao and Qin, Can and Chen, Chia-Chih and Yu, Ning and Chen, Zeyuan and Wang, Huan and Savarese, Silvio and Ermon, Stefano and Xiong, Caiming and Xu, Ran},
journal={arXiv preprint arXiv:2303.09618},
year={2023}
}