Added "Annealed importance guidance" and DRaFT+ docs (#270)
Signed-off-by: Rohit Jena <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
rohitrango and pre-commit-ci[bot] authored Sep 6, 2024
1 parent 5a4a0f8 commit da3f5f8
Showing 5 changed files with 596 additions and 6 deletions.
104 changes: 98 additions & 6 deletions docs/user-guide/draftp.rst
@@ -58,7 +58,7 @@ You can then run the following snippet to convert it to a ``.tar`` file:
Reward Model
############
Currently, we only support `PickScore-style <https://arxiv.org/pdf/2305.01569.pdf>`__ reward models (PickScore/HPSv2). Since PickScore is a CLIP-based model,
you can use the `conversion script <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py>`__ from NeMo to convert it from Hugging Face to NeMo format.
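As one possible workflow, you can first download the PickScore weights from Hugging Face and then hand them to the conversion script. The sketch below is only illustrative: the repository id and local paths are assumptions to adjust to your setup.

.. code-block:: bash

    # Fetch the PickScore reward model weights from Hugging Face
    # (repository id assumed to be the public PickScore_v1 checkpoint).
    pip install -U "huggingface_hub[cli]"
    huggingface-cli download yuvalkirstain/PickScore_v1 --local-dir /path/to/pickscore_hf

    # Then point the NeMo conversion script linked above at the downloaded
    # weights; see that script's argument parser for the exact flags it expects.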
DRaFT+ Training
@@ -81,8 +81,9 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"
DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py
torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \
trainer.num_nodes=1 \
trainer.devices=2 \
model.micro_batch_size=1 \
@@ -92,7 +93,7 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
model.unet_config.from_pretrained=${UNET_CKPT} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
rm.model.restore_from_path=${RM_CKPT} \
model.data.train.webdataset.local_root_path=${TRAIN_DATA_PATH} \
exp_manager.create_wandb_logger=False \
exp_manager.explicit_log_dir=/results
@@ -135,14 +136,16 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
MOUNTS="--container-mounts=MOUNTS" # mounts
DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py
read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \
trainer.num_nodes=1 \
trainer.devices=8 \
model.micro_batch_size=2 \
Expand All @@ -164,13 +167,102 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
.. note::
For more information on DRaFT+ hyperparameters, see the model config files (for SD and SDXL, respectively):
``NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sd.yaml``
``NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sdxl.yaml``
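Both training scripts are Hydra-based, so any value in these files can also be overridden directly on the command line by appending overrides to the training command shown earlier, instead of editing the YAML. A minimal sketch (the values are illustrative, and a given key only applies if it is present in the config you use):

.. code-block:: bash

    # Appended to the torchrun training command from the previous section
    # (illustrative values only).
    model.kl_coeff=0.1 \
    model.truncation_steps=1 \
    model.optim.lr=0.00025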
DRaFT+ Results
%%%%%%%%%%%%%%
Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on your saved model using the `sd_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py>`__
and `sd_lora_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py>`__ scripts from the NeMo codebase. Images generated with the fine-tuned model should show
better prompt alignment and aesthetic quality.
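The exact Hydra overrides for these inference scripts are defined by the config files that ship with them in NeMo rather than in this guide, so the sketch below only illustrates the general shape of such a run; every override name in it is an assumption to check against the script's config.

.. code-block:: bash

    # A minimal sketch of LoRA inference after DRaFT+ training. All override
    # keys below are assumptions -- verify them against the config shipped
    # next to sd_lora_infer.py in NeMo.
    python examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py \
        model.restore_from_path=/path/to/base_sd_model.nemo \
        model.peft.restore_from_path=/path/to/draftp_lora_checkpoint.nemo \
        infer.prompts='["Bananas growing on an apple tree"]'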
User-controllable fine-tuning with Annealed Importance Guidance (AIG)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

AIG provides inference-time flexibility to interpolate between the base Stable Diffusion model (low reward, high diversity) and the DRaFT+ fine-tuned model (high reward, low diversity), yielding images with both high reward and high diversity. AIG inference is run by specifying a comma-separated list of ``weight_type`` strategies that control how the base and fine-tuned models are interpolated.
.. tab-set::
.. tab-item:: AIG on Stable Diffusion XL
:sync: key2
A weight type of ``base`` uses the base model and ``draft`` uses the fine-tuned model (no interpolation is performed in either case).
Weight types of the form ``power_<float>`` interpolate between the two models using the exponential decay schedule described in the AIG paper.
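Several strategies can be evaluated in one run by passing them as a single comma-separated Hydra override, as in the full command further below (the particular exponent values here are illustrative):

.. code-block:: bash

    # Evaluate the fine-tuned model, the base model, and two annealing
    # schedules with different decay exponents in the same run.
    +weight_type='draft,base,power_1.0,power_4.0'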
To run AIG inference directly from the terminal:
.. code-block:: bash
NUMNODES=1
LR=${LR:=0.00025}
INF_STEPS=${INF_STEPS:=25}
KL_COEF=${KL_COEF:=0.1}
ETA=${ETA:=0.0}
DATASET=${DATASET:="pickapic50k.tar"}
MICRO_BS=${MICRO_BS:=1}
GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
PEFT=${PEFT:="sdlora"}
NUM_DEVICES=${NUM_DEVICES:=8}
GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES))
LOG_WANDB=${LOG_WANDB:="False"}
echo "additional kwargs: ${ADDITIONAL_KWARGS}"
WANDB_NAME=SDXL_Draft_annealing
WEBDATASET_PATH=/path/to/${DATASET}
CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"}
UNET_CKPT="/path/to/unet.ckpt"
VAE_CKPT="/path/to/vae.ckpt"
RM_CKPT="/path/to/reward_model.nemo"
PROMPT=${PROMPT:="Bananas growing on an apple tree"}
DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir
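# Optionally turn ACT_CKPT (if set in the environment) into a Hydra override
# that enables activation checkpointing.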
if [ ! -z "${ACT_CKPT}" ]; then
ACT_CKPT="model.activation_checkpointing=$ACT_CKPT "
echo $ACT_CKPT
fi
EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"}
export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1
set -x
CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
model.optim.lr=${LR} \
model.optim.weight_decay=0.0005 \
model.optim.sched.warmup_steps=0 \
model.sampling.base.steps=${INF_STEPS} \
model.kl_coeff=${KL_COEF} \
model.truncation_steps=1 \
trainer.draftp_sd.max_epochs=5 \
trainer.draftp_sd.max_steps=10000 \
trainer.draftp_sd.save_interval=200 \
trainer.draftp_sd.val_check_interval=20 \
trainer.draftp_sd.gradient_clip_val=10.0 \
model.micro_batch_size=${MICRO_BS} \
model.global_batch_size=${GLOBAL_BATCH_SIZE} \
model.peft.peft_scheme=${PEFT} \
model.data.webdataset.local_root_path=$WEBDATASET_PATH \
rm.model.restore_from_path=${RM_CKPT} \
trainer.devices=${NUM_DEVICES} \
trainer.num_nodes=${NUMNODES} \
rm.trainer.devices=${NUM_DEVICES} \
rm.trainer.num_nodes=${NUMNODES} \
+prompt="${PROMPT}" \
exp_manager.create_wandb_logger=${LOG_WANDB} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
model.first_stage_config.from_NeMo=True \
model.unet_config.from_pretrained=${UNET_CKPT} \
model.unet_config.from_NeMo=True \
$ACT_CKPT \
exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
exp_manager.resume_if_exists=True \
exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0'
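Because the snippet resolves most of its settings with the ``${VAR:=default}`` pattern, you can save it as a script and override individual settings from the environment at launch time. A small usage sketch (``aig_inference.sh`` is an assumed name for the saved snippet):

.. code-block:: bash

    # Override the prompt, W&B project, and batch-related settings from the
    # environment; everything else keeps the defaults from the script above.
    PROMPT="A watercolor painting of a fox in a meadow" \
    PROJECT="draftp-aig" \
    NUM_DEVICES=4 MICRO_BS=2 GRAD_ACCUMULATION=2 \
    bash aig_inference.sh
    # With NUMNODES=1, the global batch size then resolves to
    # MICRO_BS * NUM_DEVICES * GRAD_ACCUMULATION * NUMNODES = 2 * 4 * 2 * 1 = 16.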