Diffusion models have recently emerged as a potent tool in generative modeling. However, their inherently iterative sampling process requires multiple model evaluations, resulting in slow image generation. Recent work has revealed an intrinsic link between diffusion models and Probability Flow Ordinary Differential Equations (ODEs), allowing diffusion models to be viewed as ODE systems. Meanwhile, Physics Informed Neural Networks (PINNs) have proven effective at solving intricate differential equations by implicitly modeling their solutions. Building upon these foundational insights, we introduce Physics Informed Distillation (PID), which employs a student model to represent the solution of the ODE system corresponding to the teacher diffusion model, akin to the principles employed in PINNs. Through experiments on CIFAR10 and ImageNet 64x64, we observe that PID achieves performance comparable to recent distillation methods. Notably, it demonstrates predictable trends with respect to its method-specific hyperparameters and eliminates the need for synthetic dataset generation during distillation, both of which contribute to its ease of use as a distillation approach for diffusion models.
This repository is the official implementation of the paper Physics Informed Distillation for Diffusion Models, accepted by Transactions on Machine Learning Research (TMLR). It is based on openai/consistency_models, modified to support PID training and sampling.
Figure: An overview of the proposed method, in which a student model is trained to represent the solution of the teacher diffusion model's probability flow ODE.
To install all packages in this codebase along with their dependencies, run:
conda create -n pid-diffusion python=3.9
conda activate pid-diffusion
conda install pytorch=1.13.1 torchvision=0.14.1 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install -c "nvidia/label/cuda-11.6.1" libcusolver-dev
conda install mpi4py
git clone https://github.com/pantheon5100/pid_diffusion.git
cd pid_diffusion
pip install -e .
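Optionally, you can sanity-check the install by confirming that PyTorch loads and sees a GPU (the expected version string below assumes the pinned install above succeeded):

```bash
# should print the pinned torch version (1.13.1) and True on a working CUDA 11.6 setup
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```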
For the CIFAR10 and ImageNet 64x64 experiments, we use the teacher models from EDM. The released checkpoints are pickle files, so the weights must be extracted first: run the official EDM image sampling code once to save each model's state dict.
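For reference, the extraction amounts to unpickling the EDM network and saving its state dict. A minimal sketch, assuming it is run from inside the EDM codebase (so the pickled network classes can be resolved) and using example filenames:

```bash
# run from inside the EDM repository (github.com/NVlabs/edm) so unpickling succeeds
python - <<'PY'
import pickle
import torch

# example EDM checkpoint name; substitute the pickle matching your experiment
with open('edm-cifar10-32x32-uncond-vp.pkl', 'rb') as f:
    net = pickle.load(f)['ema']  # EDM stores the EMA network under the 'ema' key
torch.save(net.state_dict(), 'edm_cifar10_uncond_vp.pt')  # plain PyTorch state dict
PY
```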
We provide the extracted checkpoints for direct use:
Place the downloaded checkpoints into the `./model_zoo` directory.
To start distillation, run the provided bash script:
bash ./scripts/distill_pid_diffusion.sh
We use Open MPI to launch our code. Before running the experiment, configure the following in the bash script:

a. Set the environment variable `OPENAI_LOGDIR` to specify where the experiment data will be stored (e.g., `../experiment/EXP_NAME`, where `EXP_NAME` is the experiment name).

b. Specify the number of GPUs to use (e.g., `-np 8` to use 8 GPUs).

c. Set the total batch size across all GPUs (e.g., `--global_batch_size 512`, which results in a batch size of 512/8 = 64 per GPU).
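For illustration, here is a minimal sketch of how these three settings might look inside `./scripts/distill_pid_diffusion.sh`; the training entry-point name and anything beyond the flags listed above are assumptions, so consult the actual script:

```bash
# hypothetical excerpt -- see ./scripts/distill_pid_diffusion.sh for the real contents
# (a) where experiment data (logs, checkpoints) will be stored
export OPENAI_LOGDIR="../experiment/pid_cifar10"

# (b) launch across 8 GPUs; (c) 512 total / 8 GPUs = 64 samples per GPU
# the training entry-point name below is a placeholder
mpirun -np 8 python ./scripts/pid_train.py --global_batch_size 512
```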
Use the bash script `./scripts/image_sampling.sh` to sample images from the pre-trained teacher model or the distilled model. The distilled PID model can be downloaded here.
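If you use the distilled checkpoint, place it at the path the evaluation example below expects:

```bash
# path taken from the evaluation instructions below
mkdir -p ./model_zoo/pid_cifar
mv pid_cifar.pt ./model_zoo/pid_cifar/
```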
To evaluate FID scores, use the provided bash script `./scripts/fid_eval.sh`, which evaluates every checkpoint in the `EXP_PATH` folder. The reference statistics for the teacher model come from EDM and belong at `./model_zoo/stats/cifar10-32x32.npz` and `./model_zoo/stats/imagenet-64x64.npz`; download them with:
mkdir ./model_zoo/stats
wget https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz -P ./model_zoo/stats
wget https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz -P ./model_zoo/stats
To assess our pretrained CIFAR10 model, place it at `model_zoo/pid_cifar/pid_cifar.pt`, then run the following for evaluation:
EXP_PATH="./model_zoo/pid_cifar"
mpirun -np 1 python ./scripts/fid_evaluation.py \
--training_mode one_shot_pinn_edm_edm_one_shot \
--fid_dataset cifar10 \
--exp_dir $EXP_PATH \
--batch_size 125 \
--sigma_max 80 \
--sigma_min 0.002 \
--s_churn 0 \
--steps 35 \
--sampler oneshot \
--attention_resolutions "2" \
--class_cond False \
--dropout 0.0 \
--image_size 32 \
--num_channels 128 \
--num_res_blocks 4 \
--num_samples 50000 \
--resblock_updown True \
--use_fp16 False \
--use_scale_shift_norm True \
--weight_schedule uniform \
--seed 0
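With `--num_samples 50000` and `--batch_size 125`, the script generates 400 batches of images and compares their Inception statistics against the reference `.npz` downloaded above. For intuition, FID is the Fréchet distance between two Gaussians fitted to Inception features; a minimal sketch of that final computation follows, assuming (an assumption about the file layout) that the EDM reference files store the mean and covariance under `mu` and `sigma` keys:

```bash
python - <<'PY'
import numpy as np
import scipy.linalg

# reference statistics downloaded above; the 'mu'/'sigma' key names are an assumption
ref = np.load('./model_zoo/stats/cifar10-32x32.npz')
mu_ref, sigma_ref = ref['mu'], ref['sigma']

# in practice mu_gen/sigma_gen come from Inception features of generated samples;
# perturbed copies are used here only to keep the sketch self-contained
mu_gen = mu_ref + 0.01
sigma_gen = sigma_ref.copy()

# Frechet distance: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})
diff = mu_gen - mu_ref
covmean, _ = scipy.linalg.sqrtm(sigma_gen @ sigma_ref, disp=False)
fid = diff @ diff + np.trace(sigma_gen + sigma_ref - 2.0 * np.real(covmean))
print(float(fid))
PY
```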