diff --git a/configs/_base_/datasets/imagenet_512.py b/configs/_base_/datasets/imagenet_512.py
new file mode 100644
index 0000000000..042b141737
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_512.py
@@ -0,0 +1,45 @@
+# dataset settings
+dataset_type = 'ImageNet'
+
+# different from mmcls, we adopt the setting used in BigGAN.
+# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
+train_pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='RandomCropLongEdge', keys=['img']),
+ dict(type='Resize', scale=(512, 512), keys=['img'], backend='pillow'),
+ dict(type='Flip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='PackEditInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='CenterCropLongEdge', keys=['img']),
+ dict(type='Resize', scale=(512, 512), backend='pillow'),
+ dict(type='PackEditInputs')
+]
+
+train_dataloader = dict(
+ batch_size=None,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='./data/imagenet/',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ persistent_workers=True)
+
+val_dataloader = dict(
+ batch_size=None,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='./data/imagenet/',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=test_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ persistent_workers=True)
+
+test_dataloader = val_dataloader
diff --git a/configs/_base_/datasets/imagenet_64.py b/configs/_base_/datasets/imagenet_64.py
new file mode 100644
index 0000000000..6250de5b19
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_64.py
@@ -0,0 +1,45 @@
+# dataset settings
+dataset_type = 'ImageNet'
+
+# different from mmcls, we adopt the setting used in BigGAN.
+# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
+train_pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='RandomCropLongEdge', keys=['img']),
+ dict(type='Resize', scale=(64, 64), keys=['img'], backend='pillow'),
+ dict(type='Flip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='PackEditInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile', key='img'),
+ dict(type='CenterCropLongEdge', keys=['img']),
+ dict(type='Resize', scale=(64, 64), backend='pillow'),
+ dict(type='PackEditInputs')
+]
+
+train_dataloader = dict(
+ batch_size=None,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='./data/imagenet/',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ persistent_workers=True)
+
+val_dataloader = dict(
+ batch_size=64,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='./data/imagenet/',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=test_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ persistent_workers=True)
+
+test_dataloader = val_dataloader
diff --git a/configs/disco_diffusion/README.md b/configs/disco_diffusion/README.md
new file mode 100644
index 0000000000..b5aab49dbc
--- /dev/null
+++ b/configs/disco_diffusion/README.md
@@ -0,0 +1,135 @@
+# Disco Diffusion
+
+> [Disco Diffusion](https://github.com/alembics/disco-diffusion)
+
+> **Task**: Text2Image, Image2Image
+
+
+
+## Abstract
+
+
+
+Disco Diffusion (DD) is a Google Colab Notebook which leverages an AI Image generating technique called CLIP-Guided Diffusion to allow you to create compelling and beautiful images from text inputs.
+
+Created by Somnai, augmented by Gandamu, and building on the work of RiversHaveWings, nshepperd, and many others.
+
+
+
+
+
+
+
+## Results and models
+
+We have converted several `unet` weights and offer related configs. Or usage of different `unet`, please refer to tutorial.
+
+| Diffusion Model | Config | Weights |
+| ---------------------------------------- | --------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- |
+| 512x512_diffusion_uncond_finetune_008100 | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-512x512.py) | [weights](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u_finetuned_imagenet-512x512-ab471d70.pth) |
+| 256x256_diffusion_uncond | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-256x256.py) | [weights](<>) |
+| portrait_generator_v001 | [config](configs/disco/disco-diffusion_portrait_generator_v001.py) | [weights](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u-cvt-rgb_portrait-v001-f4a3f3bc.pth) |
+| pixelartdiffusion_expanded | Coming soon! | |
+| pixel_art_diffusion_hard_256 | Coming soon! | |
+| pixel_art_diffusion_soft_256 | Coming soon! | |
+| pixelartdiffusion4k | Coming soon! | |
+| watercolordiffusion_2 | Coming soon! | |
+| watercolordiffusion | Coming soon! | |
+| PulpSciFiDiffusion | Coming soon! | |
+
+## To-do List
+
+- [ ] pixelart, watercolor, sci-fiction diffusion models
+- [ ] image prompt
+- [ ] video generation
+- [ ] faster sampler(plms, dpm-solver etc.)
+
+We really welcome community users supporting these items and any other interesting staffs!
+
+## Quick Start
+
+Running the following codes, you can get a text-generated image.
+
+```python
+from mmengine import Config, MODELS
+from mmedit.utils import register_all_modules
+from torchvision.utils import save_image
+
+register_all_modules()
+
+disco = MODELS.build(
+ Config.fromfile('configs/disco/disco-baseline.py').model).cuda().eval()
+text_prompts = {
+ 0: [
+ "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.",
+ "yellow color scheme"
+ ]
+}
+image = disco.infer(
+ height=768,
+ width=1280,
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=250,
+ eta=0.8)['samples']
+save_image(image, "image.png")
+
+```
+
+## Tutorials
+
+Coming soon!
+
+## Credits
+
+Since our adaptation of disco-diffusion are heavily influenced by disco [colab](https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb#scrollTo=License), here we copy the credits below.
+
+
+Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.
+
+Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.
+
+Further improvements from Dango233 and nshepperd helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.
+
+Vark added code to load in multiple Clip models at once, which all prompts are evaluated against, which may greatly improve accuracy.
+
+The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri)
+
+Advanced DangoCutn Cutout method is also from Dango223.
+
+\--
+
+Disco:
+
+Somnai (https://twitter.com/Somnai_dreams) added Diffusion Animation techniques, QoL improvements and various implementations of tech and techniques, mostly listed in the changelog below.
+
+3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai. Creation of disco.py and ongoing maintenance.
+
+Turbo feature by Chris Allen (https://twitter.com/zippy731)
+
+Improvements to ability to run on local systems, Windows support, and dependency installation by HostsServer (https://twitter.com/HostsServer)
+
+VR Mode by Tom Mason (https://twitter.com/nin_artificial)
+
+Horizontal and Vertical symmetry functionality by nshepperd. Symmetry transformation_steps by huemin (https://twitter.com/huemin_art). Symmetry integration into Disco Diffusion by Dmitrii Tochilkin (https://twitter.com/cut_pow).
+
+Warp and custom model support by Alex Spirin (https://twitter.com/devdef).
+
+Pixel Art Diffusion, Watercolor Diffusion, and Pulp SciFi Diffusion models from KaliYuga (https://twitter.com/KaliYuga_ai). Follow KaliYuga's Twitter for the latest models and for notebooks with specialized settings.
+
+Integration of OpenCLIP models and initiation of integration of KaliYuga models by Palmweaver / Chris Scalf (https://twitter.com/ChrisScalf11)
+
+Integrated portrait_generator_v001 from Felipe3DArtist (https://twitter.com/Felipe3DArtist)
+
+
+
+## Citation
+
+```bibtex
+@misc{github,
+ author={alembics},
+ title={disco-diffusion},
+ year={2022},
+ url={https://github.com/alembics/disco-diffusion},
+}
+```
diff --git a/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-256x256.py b/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-256x256.py
new file mode 100644
index 0000000000..8bd44d064c
--- /dev/null
+++ b/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-256x256.py
@@ -0,0 +1,47 @@
+unet = dict(
+ type='DenoisingUnet',
+ image_size=256,
+ in_channels=3,
+ base_channels=256,
+ resblocks_per_downsample=2,
+ attention_res=(32, 16, 8),
+ norm_cfg=dict(type='GN32', num_groups=32),
+ dropout=0.0,
+ num_classes=0,
+ use_fp16=True,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=4,
+ num_head_channels=64,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True)
+
+unet_ckpt_path = 'work_dirs/adm-cvt-rgb_finetuned_imagenet-256x256.pth' # noqa
+secondary_model_ckpt_path = 'https://download.openmmlab.com/mmediting/synthesizers/disco/secondary_model_imagenet_2.pth' # noqa
+pretrained_cfgs = dict(
+ unet=dict(ckpt_path=unet_ckpt_path, prefix='unet'),
+ secondary_model=dict(ckpt_path=secondary_model_ckpt_path, prefix=''))
+
+secondary_model = dict(type='SecondaryDiffusionImageNet2')
+
+diffusion_scheduler = dict(
+ type='DDIMScheduler',
+ variance_type='learned_range',
+ beta_schedule='linear',
+ clip_sample=False)
+
+clip_models = [
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/32', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/16', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='RN50', jit=False)
+]
+
+model = dict(
+ type='DiscoDiffusion',
+ unet=unet,
+ diffusion_scheduler=diffusion_scheduler,
+ secondary_model=secondary_model,
+ clip_models=clip_models,
+ use_fp16=True,
+ pretrained_cfgs=pretrained_cfgs)
diff --git a/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-512x512.py b/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-512x512.py
new file mode 100644
index 0000000000..f839a5a7b6
--- /dev/null
+++ b/configs/disco_diffusion/disco-diffusion_adm-u-finetuned_imagenet-512x512.py
@@ -0,0 +1,47 @@
+unet = dict(
+ type='DenoisingUnet',
+ image_size=512,
+ in_channels=3,
+ base_channels=256,
+ resblocks_per_downsample=2,
+ attention_res=(32, 16, 8),
+ norm_cfg=dict(type='GN32', num_groups=32),
+ dropout=0.0,
+ num_classes=0,
+ use_fp16=True,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=4,
+ num_head_channels=64,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True)
+
+unet_ckpt_path = 'https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u_finetuned_imagenet-512x512-ab471d70.pth' # noqa
+secondary_model_ckpt_path = 'https://download.openmmlab.com/mmediting/synthesizers/disco/secondary_model_imagenet_2.pth' # noqa
+pretrained_cfgs = dict(
+ unet=dict(ckpt_path=unet_ckpt_path, prefix='unet'),
+ secondary_model=dict(ckpt_path=secondary_model_ckpt_path, prefix=''))
+
+secondary_model = dict(type='SecondaryDiffusionImageNet2')
+
+diffusion_scheduler = dict(
+ type='DDIMScheduler',
+ variance_type='learned_range',
+ beta_schedule='linear',
+ clip_sample=False)
+
+clip_models = [
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/32', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/16', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='RN50', jit=False)
+]
+
+model = dict(
+ type='DiscoDiffusion',
+ unet=unet,
+ diffusion_scheduler=diffusion_scheduler,
+ secondary_model=secondary_model,
+ clip_models=clip_models,
+ use_fp16=True,
+ pretrained_cfgs=pretrained_cfgs)
diff --git a/configs/disco_diffusion/disco-diffusion_portrait_generator_v001.py b/configs/disco_diffusion/disco-diffusion_portrait_generator_v001.py
new file mode 100644
index 0000000000..884240b975
--- /dev/null
+++ b/configs/disco_diffusion/disco-diffusion_portrait_generator_v001.py
@@ -0,0 +1,7 @@
+_base_ = ['./disco-diffusion_adm-u-finetuned_imagenet-512x512.py']
+unet_ckpt_path = 'https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u-cvt-rgb_portrait-v001-f4a3f3bc.pth' # noqa
+model = dict(
+ unet=dict(base_channels=128),
+ secondary_model=None,
+ pretrained_cfgs=dict(
+ _delete_=True, unet=dict(ckpt_path=unet_ckpt_path, prefix='unet')))
diff --git a/configs/disco_diffusion/metafile.yml b/configs/disco_diffusion/metafile.yml
new file mode 100644
index 0000000000..f203d6ea27
--- /dev/null
+++ b/configs/disco_diffusion/metafile.yml
@@ -0,0 +1,9 @@
+Collections:
+- Metadata:
+ Architecture:
+ - Disco Diffusion
+ Name: Disco Diffusion
+ Paper:
+ - https://github.com/alembics/disco-diffusion
+ README: configs/disco_diffusion/README.md
+Models: []
diff --git a/configs/guided_diffusion/README.md b/configs/guided_diffusion/README.md
new file mode 100644
index 0000000000..eb659b8e72
--- /dev/null
+++ b/configs/guided_diffusion/README.md
@@ -0,0 +1,45 @@
+# Guided Diffusion (NeurIPS'2021)
+
+> [Diffusion Models Beat GANs on Image Synthesis](https://papers.nips.cc/paper/2021/file/49ad23d1ec9fa4bd8d77d02681df5cfa-Paper.pdf)
+
+> **Task**: Image Generation
+
+
+
+## Abstract
+
+
+
+We show that diffusion models can achieve image sample quality superior to the current state-of-the-art generative models. We achieve this on unconditional image synthesis by finding a better architecture through a series of ablations. For conditional image synthesis, we further improve sample quality with classifier guidance: a simple, compute-efficient method for trading off diversity for fidelity using gradients from a classifier. We achieve an FID of 2.97 on ImageNet 128x128, 4.59 on ImageNet 256x256, and 7.72 on ImageNet 512x512, and we match BigGAN-deep even with as few as 25 forward passes per sample, all while maintaining better coverage of the distribution. Finally, we find that classifier guidance combines well with upsampling diffusion models, further improving FID to 3.94 on ImageNet 256x256 and 3.85 on ImageNet 512x512.
+
+
+
+
+
+
+
+## Results and models
+
+**ImageNet**
+
+| Method | Resolution | Config | Weights |
+| ------ | ---------- | ------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------- |
+| adm-u | 64x64 | [config](configs/guided_diffusion/adm-u_8xb32_imagenet-64x64.py) | [model](https://download.openmmlab.com/mmgen/guided_diffusion/adm-u-cvt-rgb_8xb32_imagenet-64x64-7ff0080b.pth) |
+| adm-u | 512x512 | [config](configs/guided_diffusion/adm-u_8xb32_imagenet-512x512.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmgen/guided_diffusion/adm-u_8xb32_imagenet-512x512-60b381cb.pth) |
+
+**Note** To support disco diffusion, we support guided diffusion briefly. Complete support of guided diffusion with metrics and test/train logs will come soom!
+
+## Quick Start
+
+Coming soon!
+
+## Citation
+
+```bibtex
+@article{PrafullaDhariwal2021DiffusionMB,
+ title={Diffusion Models Beat GANs on Image Synthesis},
+ author={Prafulla Dhariwal and Alex Nichol},
+ journal={arXiv: Learning},
+ year={2021}
+}
+```
diff --git a/configs/guided_diffusion/adm-u_8xb32_imagenet-512x512.py b/configs/guided_diffusion/adm-u_8xb32_imagenet-512x512.py
new file mode 100644
index 0000000000..35d841b617
--- /dev/null
+++ b/configs/guided_diffusion/adm-u_8xb32_imagenet-512x512.py
@@ -0,0 +1,45 @@
+_base_ = [
+ '../_base_/datasets/imagenet_512.py',
+ '../_base_/gen_default_runtime.py',
+]
+
+model = dict(
+ type='AblatedDiffusionModel',
+ data_preprocessor=dict(
+ type='EditDataPreprocessor', mean=[127.5], std=[127.5]),
+ unet=dict(
+ type='DenoisingUnet',
+ image_size=512,
+ in_channels=3,
+ base_channels=256,
+ resblocks_per_downsample=2,
+ attention_res=(32, 16, 8),
+ norm_cfg=dict(type='GN32', num_groups=32),
+ dropout=0.1,
+ num_classes=1000,
+ use_fp16=False,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=4,
+ num_head_channels=64,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True),
+ diffusion_scheduler=dict(
+ type='DDPMScheduler',
+ variance_type='learned_range',
+ beta_schedule='linear'),
+ use_fp16=False)
+
+test_dataloader = dict(batch_size=32, num_workers=8)
+
+metrics = [
+ dict(
+ type='FrechetInceptionDistance',
+ prefix='FID-Full-50k',
+ fake_nums=50000,
+ inception_style='StyleGAN')
+]
+
+val_evaluator = dict(metrics=metrics)
+test_evaluator = dict(metrics=metrics)
diff --git a/configs/guided_diffusion/adm-u_8xb32_imagenet-64x64.py b/configs/guided_diffusion/adm-u_8xb32_imagenet-64x64.py
new file mode 100644
index 0000000000..5150b04d6e
--- /dev/null
+++ b/configs/guided_diffusion/adm-u_8xb32_imagenet-64x64.py
@@ -0,0 +1,45 @@
+_base_ = [
+ '../_base_/datasets/imagenet_64.py',
+ '../_base_/gen_default_runtime.py',
+]
+
+model = dict(
+ type='AblatedDiffusionModel',
+ data_preprocessor=dict(
+ type='EditDataPreprocessor', mean=[127.5], std=[127.5]),
+ unet=dict(
+ type='DenoisingUnet',
+ image_size=64,
+ in_channels=3,
+ base_channels=192,
+ resblocks_per_downsample=3,
+ attention_res=(32, 16, 8),
+ norm_cfg=dict(type='GN32', num_groups=32),
+ dropout=0.1,
+ num_classes=1000,
+ use_fp16=False,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=4,
+ num_head_channels=64,
+ use_new_attention_order=True),
+ use_scale_shift_norm=True),
+ diffusion_scheduler=dict(
+ type='DDPMScheduler',
+ variance_type='learned_range',
+ beta_schedule='squaredcos_cap_v2'),
+ use_fp16=False)
+
+test_dataloader = dict(batch_size=32, num_workers=8)
+
+metrics = [
+ dict(
+ type='FrechetInceptionDistance',
+ prefix='FID-Full-50k',
+ fake_nums=50000,
+ inception_style='StyleGAN')
+]
+
+val_evaluator = dict(metrics=metrics)
+test_evaluator = dict(metrics=metrics)
diff --git a/configs/guided_diffusion/adm-u_ddim250_8xb32_imagenet-512x512.py b/configs/guided_diffusion/adm-u_ddim250_8xb32_imagenet-512x512.py
new file mode 100644
index 0000000000..0540aebc2a
--- /dev/null
+++ b/configs/guided_diffusion/adm-u_ddim250_8xb32_imagenet-512x512.py
@@ -0,0 +1,45 @@
+_base_ = [
+ '../_base_/datasets/imagenet_512.py',
+ '../_base_/gen_default_runtime.py',
+]
+
+model = dict(
+ type='AblatedDiffusionModel',
+ data_preprocessor=dict(
+ type='EditDataPreprocessor', mean=[127.5], std=[127.5]),
+ unet=dict(
+ type='DenoisingUnet',
+ image_size=512,
+ in_channels=3,
+ base_channels=256,
+ resblocks_per_downsample=2,
+ attention_res=(32, 16, 8),
+ norm_cfg=dict(type='GN32', num_groups=32),
+ dropout=0.1,
+ num_classes=1000,
+ use_fp16=False,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=4,
+ num_head_channels=64,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True),
+ diffusion_scheduler=dict(
+ type='DDIMScheduler',
+ variance_type='learned_range',
+ beta_schedule='linear'),
+ use_fp16=False)
+
+test_dataloader = dict(batch_size=32, num_workers=8)
+
+metrics = [
+ dict(
+ type='FrechetInceptionDistance',
+ prefix='FID-Full-50k',
+ fake_nums=50000,
+ inception_style='StyleGAN')
+]
+
+val_evaluator = dict(metrics=metrics)
+test_evaluator = dict(metrics=metrics)
diff --git a/configs/guided_diffusion/metafile.yml b/configs/guided_diffusion/metafile.yml
new file mode 100644
index 0000000000..1eedef6e1a
--- /dev/null
+++ b/configs/guided_diffusion/metafile.yml
@@ -0,0 +1,9 @@
+Collections:
+- Metadata:
+ Architecture:
+ - Guided Diffusion
+ Name: Guided Diffusion
+ Paper:
+ - https://papers.nips.cc/paper/2021/file/49ad23d1ec9fa4bd8d77d02681df5cfa-Paper.pdf
+ README: configs/guided_diffusion/README.md
+Models: []
diff --git a/mmedit/datasets/transforms/random_degradations.py b/mmedit/datasets/transforms/random_degradations.py
index 64433c43d5..65e7fb7849 100644
--- a/mmedit/datasets/transforms/random_degradations.py
+++ b/mmedit/datasets/transforms/random_degradations.py
@@ -292,9 +292,6 @@ def _apply_random_noise(self, imgs):
Args:
imgs (Tensor): training images
- Raises:
- NotImplementedError: _description_
-
Returns:
_type_: _description_
"""
@@ -361,10 +358,7 @@ def _random_resize(self, imgs):
augmentation.
Args:
- imgs (Tensor): training images
-
- Raises:
- NotImplementedError: _description_
+ imgs (Tensor): training images.
Returns:
Tensor: images after radomly resized
diff --git a/mmedit/models/editors/__init__.py b/mmedit/models/editors/__init__.py
index d3d9aefc0f..ce2ebfc506 100644
--- a/mmedit/models/editors/__init__.py
+++ b/mmedit/models/editors/__init__.py
@@ -7,6 +7,8 @@
from .cain import CAIN, CAINNet
from .cyclegan import CycleGAN
from .dcgan import DCGAN
+from .ddim import DDIMScheduler
+from .ddpm import DDPMScheduler, DenoisingUnet
from .deepfillv1 import (ContextualAttentionModule, ContextualAttentionNeck,
DeepFillDecoder, DeepFillEncoder, DeepFillRefiner,
DeepFillv1Discriminators, DeepFillv1Inpaintor)
@@ -14,6 +16,7 @@
from .dic import (DIC, DICNet, FeedbackBlock, FeedbackBlockCustom,
FeedbackBlockHeatmapAttention, LightCNN, MaxFeature)
from .dim import DIM
+from .disco_diffusion import ClipWrapper, DiscoDiffusion
from .edsr import EDSRNet
from .edvr import EDVR, EDVRNet
from .esrgan import ESRGAN, RRDBNet
@@ -24,6 +27,7 @@
from .glean import GLEANStyleGANv2
from .global_local import (GLDecoder, GLDilationNeck, GLEncoder,
GLEncoderDecoder)
+from .guided_diffusion import AblatedDiffusionModel
from .iconvsr import IconVSRNet
from .indexnet import (DepthwiseIndexBlock, HolisticIndexBlock,
IndexedUpsample, IndexNet, IndexNetDecoder,
@@ -74,7 +78,9 @@
'LIIF', 'MLPRefiner', 'PlainRefiner', 'PlainDecoder', 'FBAResnetDilated',
'FBADecoder', 'WGANGP', 'CycleGAN', 'SAGAN', 'LSGAN', 'GGAN', 'Pix2Pix',
'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN',
- 'ProgressiveGrowingGAN', 'SinGAN', 'IDLossModel', 'PESinGAN',
- 'MSPIEStyleGAN2', 'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
- 'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal'
+ 'ProgressiveGrowingGAN', 'SinGAN', 'AblatedDiffusionModel',
+ 'DiscoDiffusion', 'IDLossModel', 'PESinGAN', 'MSPIEStyleGAN2',
+ 'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
+ 'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DDIMScheduler',
+ 'DDPMScheduler', 'DenoisingUnet', 'ClipWrapper'
]
diff --git a/mmedit/models/editors/ddim/__init__.py b/mmedit/models/editors/ddim/__init__.py
new file mode 100644
index 0000000000..4b14e89b77
--- /dev/null
+++ b/mmedit/models/editors/ddim/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ddim_scheduler import DDIMScheduler
+
+__all__ = ['DDIMScheduler']
diff --git a/mmedit/models/editors/ddim/ddim_scheduler.py b/mmedit/models/editors/ddim/ddim_scheduler.py
new file mode 100644
index 0000000000..104501d9e5
--- /dev/null
+++ b/mmedit/models/editors/ddim/ddim_scheduler.py
@@ -0,0 +1,220 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import numpy as np
+import torch
+
+from mmedit.registry import DIFFUSION_SCHEDULERS
+from ...utils.diffusion_utils import betas_for_alpha_bar
+
+
+@DIFFUSION_SCHEDULERS.register_module()
+class DDIMScheduler:
+ """```DDIMScheduler``` support the diffusion and reverse process formulated
+ in https://arxiv.org/abs/2010.02502.
+
+ The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. # noqa
+ The difference is that we ensemble gradient-guided sampling in step function.
+
+ Args:
+ num_train_timesteps (int, optional): _description_. Defaults to 1000.
+ beta_start (float, optional): _description_. Defaults to 0.0001.
+ beta_end (float, optional): _description_. Defaults to 0.02.
+ beta_schedule (str, optional): _description_. Defaults to "linear".
+ variance_type (str, optional): _description_. Defaults to 'learned_range'.
+ timestep_values (_type_, optional): _description_. Defaults to None.
+ clip_sample (bool, optional): _description_. Defaults to True.
+ set_alpha_to_one (bool, optional): _description_. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule='linear',
+ variance_type='learned_range',
+ timestep_values=None,
+ clip_sample=True,
+ set_alpha_to_one=True,
+ ):
+ self.num_train_timesteps = num_train_timesteps
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.beta_schedule = beta_schedule
+ self.variance_type = variance_type
+ self.timestep_values = timestep_values
+ self.clip_sample = clip_sample
+ self.set_alpha_to_one = set_alpha_to_one
+
+ if beta_schedule == 'linear':
+ self.betas = np.linspace(
+ beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == 'scaled_linear':
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=np.float32)**2
+ elif beta_schedule == 'squaredcos_cap_v2':
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(
+ f'{beta_schedule} does is not implemented for {self.__class__}'
+ )
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the
+ # previous alphas_cumprod. For the final step,
+ # there is no previous alphas_cumprod because we are already
+ # at 0 `set_alpha_to_one` decides whether we set this paratemer
+ # simply to one or whether we use the final alpha of the
+ # "non-previous" one.
+ self.final_alpha_cumprod = np.array(
+ 1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ def set_timesteps(self, num_inference_steps, offset=0):
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.num_train_timesteps,
+ self.num_train_timesteps // self.num_inference_steps)[::-1].copy()
+ self.timesteps += offset
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ variance = (beta_prod_t_prev /
+ beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ return variance
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ cond_fn=None,
+ cond_kwargs={},
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ ):
+ output = {}
+ if self.num_inference_steps is None:
+ raise ValueError("Number of inference steps is 'None', '\
+ 'you need to run 'set_timesteps' '\
+ 'after creating the scheduler")
+
+ pred = None
+ if isinstance(model_output, dict):
+ pred = model_output['pred']
+ model_output = model_output['eps']
+ elif model_output.shape[1] == sample.shape[
+ 1] * 2 and self.variance_type in ['learned', 'learned_range']:
+ model_output, _ = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ raise TypeError
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # noqa
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointingc to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = (
+ timestep - self.num_train_timesteps // self.num_inference_steps)
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # noqa
+ pred_original_sample = (sample - (
+ (beta_prod_t)**(0.5)) * model_output) / alpha_prod_t**(0.5)
+ if pred is not None:
+ pred_original_sample = pred
+
+ gradient = 0.
+ if cond_fn is not None:
+ gradient = cond_fn(
+ cond_kwargs.pop('unet'), self, sample, timestep, beta_prod_t,
+ cond_kwargs.pop('model_stats'), **cond_kwargs)
+ model_output = model_output - (beta_prod_t**0.5) * gradient
+ pred_original_sample = (
+ sample -
+ (beta_prod_t**(0.5)) * model_output) / alpha_prod_t**(0.5)
+ # 4. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance**(0.5)
+ output.update(dict(sigma=std_dev_t))
+
+ if use_clipped_model_output:
+ # the model_output is always
+ # re-derived from the clipped x_0 in Glide
+ model_output = (sample - (alpha_prod_t**(0.5)) *
+ pred_original_sample) / beta_prod_t**(0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # noqa
+ pred_sample_direction = (1 - alpha_prod_t_prev -
+ std_dev_t**2)**(0.5) * model_output
+
+ # 7. compute x_t without "random noise" of
+ # formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_mean = alpha_prod_t_prev**(
+ 0.5) * pred_original_sample + pred_sample_direction
+ output.update(dict(mean=prev_mean, prev_sample=prev_mean))
+
+ if eta > 0:
+ device = model_output.device if torch.is_tensor(
+ model_output) else 'cpu'
+ noise = torch.randn(
+ model_output.shape, generator=generator).to(device)
+ variance = std_dev_t * noise
+
+ if not torch.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_mean + variance
+ output.update({'prev_sample': prev_sample})
+
+ # NOTE: this x0 is twice computed
+ output.update({
+ 'original_sample': pred_original_sample,
+ 'beta_prod_t': beta_prod_t
+ })
+ return output
+
+ def add_noise(self, original_samples, noise, timesteps):
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
+ noisy_samples = (
+ sqrt_alpha_prod * original_samples +
+ sqrt_one_minus_alpha_prod * noise)
+ return noisy_samples
+
+ def __len__(self):
+ return self.num_train_timesteps
diff --git a/mmedit/models/editors/ddpm/__init__.py b/mmedit/models/editors/ddpm/__init__.py
new file mode 100644
index 0000000000..2b94f11031
--- /dev/null
+++ b/mmedit/models/editors/ddpm/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ddpm_scheduler import DDPMScheduler
+from .denoising_unet import DenoisingUnet
+
+__all__ = ['DDPMScheduler', 'DenoisingUnet']
diff --git a/mmedit/models/editors/ddpm/ddpm_scheduler.py b/mmedit/models/editors/ddpm/ddpm_scheduler.py
new file mode 100644
index 0000000000..d06f2e85bb
--- /dev/null
+++ b/mmedit/models/editors/ddpm/ddpm_scheduler.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import numpy as np
+import torch
+
+from mmedit.registry import DIFFUSION_SCHEDULERS
+from ...utils.diffusion_utils import betas_for_alpha_bar
+
+
+@DIFFUSION_SCHEDULERS.register_module()
+class DDPMScheduler:
+
+ def __init__(self,
+ num_train_timesteps=1000,
+ beta_start=0.0001,
+ beta_end=0.02,
+ beta_schedule='linear',
+ trained_betas=None,
+ variance_type='fixed_small',
+ clip_sample=True):
+ """```DDPMScheduler``` support the diffusion and reverse process
+ formulated in https://arxiv.org/abs/2006.11239.
+
+ The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py. # noqa
+
+ Args:
+ num_train_timesteps (int, optional): _description_. Defaults to 1000.
+ beta_start (float, optional): _description_. Defaults to 0.0001.
+ beta_end (float, optional): _description_. Defaults to 0.02.
+ beta_schedule (str, optional): _description_. Defaults to 'linear'.
+ trained_betas (_type_, optional): _description_. Defaults to None.
+ variance_type (str, optional): _description_. Defaults to 'fixed_small'.
+ clip_sample (bool, optional): _description_. Defaults to True.
+ """
+ self.num_train_timesteps = num_train_timesteps
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == 'linear':
+ self.betas = np.linspace(
+ beta_start, beta_end, num_train_timesteps, dtype=np.float64)
+ elif beta_schedule == 'scaled_linear':
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=np.float32)**2
+ elif beta_schedule == 'squaredcos_cap_v2':
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(
+ f'{beta_schedule} does is not implemented for {self.__class__}'
+ )
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.one = np.array(1.0)
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.variance_type = variance_type
+ self.clip_sample = clip_sample
+
+ def set_timesteps(self, num_inference_steps):
+ num_inference_steps = min(self.num_train_timesteps,
+ num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.num_train_timesteps,
+ self.num_train_timesteps // self.num_inference_steps)[::-1].copy()
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # noqa
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample # noqa
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if t == 0:
+ log_variance = (1 - alpha_prod_t_prev) / (
+ 1 - alpha_prod_t) * self.betas[1]
+ else:
+ log_variance = np.log(variance)
+
+ if variance_type is None:
+ variance_type = self.variance_type
+
+ # hacks - were probs added for training stability
+ if variance_type == 'fixed_small':
+ variance = np.clip(variance, min_value=1e-20)
+ # for rl-diffusion_scheduler https://arxiv.org/abs/2205.09991
+ elif variance_type == 'fixed_small_log':
+ variance = np.log(np.clip(variance, min_value=1e-20))
+ elif variance_type == 'fixed_large':
+ variance = self.betas[t]
+ elif variance_type == 'fixed_large_log':
+ # Glide max_log
+ variance = np.log(self.betas[t])
+ elif variance_type == 'learned':
+ return predicted_variance
+ elif variance_type == 'learned_range':
+ min_log = log_variance
+ max_log = np.log(self.betas[t])
+ frac = (predicted_variance + 1) / 2
+ log_variance = frac * max_log + (1 - frac) * min_log
+ variance = torch.exp(log_variance)
+
+ return variance
+
+ def step(self,
+ model_output: Union[torch.FloatTensor],
+ timestep: int,
+ sample: Union[torch.FloatTensor],
+ predict_epsilon=True,
+ generator=None):
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[
+ 1] * 2 and self.variance_type in ['learned', 'learned_range']:
+ model_output, predicted_variance = torch.split(
+ model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # noqa
+ if predict_epsilon:
+ pred_original_sample = (
+ (sample - beta_prod_t**(0.5) * model_output) /
+ alpha_prod_t**(0.5))
+ else:
+ pred_original_sample = model_output
+
+ # 3. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # noqa
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev**(0.5) *
+ self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t]**(
+ 0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_mean = (
+ pred_original_sample_coeff * pred_original_sample +
+ current_sample_coeff * sample)
+
+ # 6. Add noise
+ noise = torch.randn_like(model_output)
+ sigma = 0
+ if t > 0:
+ sigma = self._get_variance(
+ t, predicted_variance=predicted_variance)**0.5
+
+ pred_prev_sample = pred_prev_mean + sigma * noise
+
+ return {
+ 'prev_sample': pred_prev_sample,
+ 'mean': pred_prev_mean,
+ 'sigma': sigma,
+ 'noise': noise
+ }
+
+ def add_noise(self, original_samples, noise, timesteps):
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod,
+ original_samples)
+
+ noisy_samples = (
+ sqrt_alpha_prod * original_samples +
+ sqrt_one_minus_alpha_prod * noise)
+ return noisy_samples
+
+ def training_loss(self, model, x_0, t):
+ raise NotImplementedError(
+ 'This function is supposed to return '
+ 'a dict containing loss items giving sampled x0 and timestep.')
+
+ def sample_timestep(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ return self.num_train_timesteps
diff --git a/mmedit/models/editors/ddpm/denoising_unet.py b/mmedit/models/editors/ddpm/denoising_unet.py
new file mode 100644
index 0000000000..1154825077
--- /dev/null
+++ b/mmedit/models/editors/ddpm/denoising_unet.py
@@ -0,0 +1,1001 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from copy import deepcopy
+from functools import partial
+
+import mmengine
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks import build_norm_layer
+from mmcv.cnn.bricks.conv_module import ConvModule
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModule, constant_init
+from mmengine.runner import load_checkpoint
+from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
+
+from mmedit.registry import MODELS, MODULES
+
+
+class EmbedSequential(nn.Sequential):
+ """A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+
+ Modified from
+ https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/unet.py#L35
+ """
+
+ def forward(self, x, y):
+ for layer in self:
+ if isinstance(layer, DenoisingResBlock):
+ x = layer(x, y)
+ else:
+ x = layer(x)
+ return x
+
+
+@MODELS.register_module('GN32')
+class GroupNorm32(nn.GroupNorm):
+
+ def __init__(self, num_channels, num_groups=32, **kwargs):
+ super().__init__(num_groups, num_channels, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def convert_module_to_f16(layer):
+ """Convert primitive modules to float16."""
+ if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ layer.weight.data = layer.weight.data.half()
+ if layer.bias is not None:
+ layer.bias.data = layer.bias.data.half()
+
+
+def convert_module_to_f32(layer):
+ """Convert primitive modules to float32, undoing
+ convert_module_to_f16()."""
+ if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ layer.weight.data = layer.weight.data.float()
+ if layer.bias is not None:
+ layer.bias.data = layer.bias.data.float()
+
+
+@MODELS.register_module()
+class SiLU(BaseModule):
+ r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+ The SiLU function is also known as the swish function.
+ Args:
+ input (bool, optional): Use inplace operation or not.
+ Defaults to `False`.
+ """
+
+ def __init__(self, inplace=False):
+ super().__init__()
+ if digit_version(TORCH_VERSION) <= digit_version('1.6.0') and inplace:
+ mmengine.print_log(
+ 'Inplace version of \'SiLU\' is not supported for '
+ f'torch < 1.6.0, found \'{torch.version}\'.')
+ self.inplace = inplace
+
+ def forward(self, x):
+ """Forward function for SiLU.
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Tensor after activation.
+ """
+
+ if digit_version(TORCH_VERSION) <= digit_version('1.6.0'):
+ return x * torch.sigmoid(x)
+
+ return F.silu(x, inplace=self.inplace)
+
+
+@MODULES.register_module()
+class MultiHeadAttention(BaseModule):
+ """An attention block allows spatial position to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. # noqa
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ num_heads (int, optional): Number of heads in the attention.
+ norm_cfg (dict, optional): Config for normalization layer. Default
+ to ``dict(type='GN', num_groups=32)``
+ """
+
+ def __init__(self,
+ in_channels,
+ num_heads=1,
+ norm_cfg=dict(type='GN', num_groups=32)):
+ super().__init__()
+ self.num_heads = num_heads
+ _, self.norm = build_norm_layer(norm_cfg, in_channels)
+ self.qkv = nn.Conv1d(in_channels, in_channels * 3, 1)
+ self.proj = nn.Conv1d(in_channels, in_channels, 1)
+ self.init_weights()
+
+ @staticmethod
+ def QKVAttention(qkv):
+ channel = qkv.shape[1] // 3
+ q, k, v = torch.chunk(qkv, 3, dim=1)
+ scale = 1 / np.sqrt(np.sqrt(channel))
+ weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ weight = torch.einsum('bts,bcs->bct', weight, v)
+ return weight
+
+ def forward(self, x):
+ """Forward function for multi head attention.
+ Args:
+ x (torch.Tensor): Input feature map.
+
+ Returns:
+ torch.Tensor: Feature map after attention.
+ """
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
+ h = self.QKVAttention(qkv)
+ h = h.reshape(b, -1, h.shape[-1])
+ h = self.proj(h)
+ return (h + x).reshape(b, c, *spatial)
+
+ def init_weights(self):
+ constant_init(self.proj, 0)
+
+
+@MODULES.register_module()
+class MultiHeadAttentionBlock(BaseModule):
+ """An attention block that allows spatial positions to attend to each
+ other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(self,
+ in_channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_new_attention_order=False,
+ norm_cfg=dict(type='GN32', num_groups=32)):
+ super().__init__()
+ self.in_channels = in_channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (in_channels % num_head_channels == 0), (
+ f'q,k,v channels {in_channels} is not divisible by '
+ 'num_head_channels {num_head_channels}')
+ self.num_heads = in_channels // num_head_channels
+ _, self.norm = build_norm_layer(norm_cfg, in_channels)
+ self.qkv = nn.Conv1d(in_channels, in_channels * 3, 1)
+
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = nn.Conv1d(in_channels, in_channels, 1)
+
+ def forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+@MODULES.register_module()
+class QKVAttentionLegacy(BaseModule):
+ """A module which performs QKV attention.
+
+ Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
+ ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ 'bct,bcs->bts', q * scale,
+ k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = torch.einsum('bts,bcs->bct', weight, v)
+ return a.reshape(bs, -1, length)
+
+
+@MODULES.register_module()
+class QKVAttention(BaseModule):
+ """A module which performs QKV attention and splits in a different
+ order."""
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ 'bct,bcs->bts',
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = torch.einsum('bts,bcs->bct', weight,
+ v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+
+@MODULES.register_module()
+class TimeEmbedding(BaseModule):
+ """Time embedding layer, reference to Two level embedding. First embedding
+ time by an embedding function, then feed to neural networks.
+
+ Args:
+ in_channels (int): The channel number of the input feature map.
+ embedding_channels (int): The channel number of the output embedding.
+ embedding_mode (str, optional): Embedding mode for the time embedding.
+ Defaults to 'sin'.
+ embedding_cfg (dict, optional): Config for time embedding.
+ Defaults to None.
+ act_cfg (dict, optional): Config for activation layer. Defaults to
+ ``dict(type='SiLU', inplace=False)``.
+ """
+
+ def __init__(self,
+ in_channels,
+ embedding_channels,
+ embedding_mode='sin',
+ embedding_cfg=None,
+ act_cfg=dict(type='SiLU', inplace=False)):
+ super().__init__()
+ self.blocks = nn.Sequential(
+ nn.Linear(in_channels, embedding_channels), MODELS.build(act_cfg),
+ nn.Linear(embedding_channels, embedding_channels))
+
+ # add `dim` to embedding config
+ embedding_cfg_ = dict(dim=in_channels)
+ if embedding_cfg is not None:
+ embedding_cfg_.update(embedding_cfg)
+ if embedding_mode.upper() == 'SIN':
+ self.embedding_fn = partial(self.sinusodial_embedding,
+ **embedding_cfg_)
+ else:
+ raise ValueError('Only support `SIN` for time embedding, '
+ f'but receive {embedding_mode}.')
+
+ @staticmethod
+ def sinusodial_embedding(timesteps, dim, max_period=10000):
+ """Create sinusoidal timestep embeddings.
+
+ Args:
+ timesteps (torch.Tensor): Timestep to embedding. 1-D tensor shape
+ as ``[bz, ]``, one per batch element.
+ dim (int): The dimension of the embedding.
+ max_period (int, optional): Controls the minimum frequency of the
+ embeddings. Defaults to ``10000``.
+
+ Returns:
+ torch.Tensor: Embedding results shape as `[bz, dim]`.
+ """
+
+ half = dim // 2
+ freqs = torch.exp(
+ -np.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ """Forward function for time embedding layer.
+ Args:
+ t (torch.Tensor): Input timesteps.
+
+ Returns:
+ torch.Tensor: Timesteps embedding.
+
+ """
+ return self.blocks(self.embedding_fn(t))
+
+
+@MODULES.register_module()
+class DenoisingResBlock(BaseModule):
+ """Resblock for the denoising network. If `in_channels` not equals to
+ `out_channels`, a learnable shortcut with conv layers will be added.
+
+ Args:
+ in_channels (int): Number of channels of the input feature map.
+ embedding_channels (int): Number of channels of the input embedding.
+ use_scale_shift_norm (bool): Whether use scale-shift-norm in
+ `NormWithEmbedding` layer.
+ dropout (float): Probability of the dropout layers.
+ out_channels (int, optional): Number of output channels of the
+ ResBlock. If not defined, the output channels will equal to the
+ `in_channels`. Defaults to `None`.
+ norm_cfg (dict, optional): The config for the normalization layers.
+ Defaults too ``dict(type='GN', num_groups=32)``.
+ act_cfg (dict, optional): The config for the activation layers.
+ Defaults to ``dict(type='SiLU', inplace=False)``.
+ shortcut_kernel_size (int, optional): The kernel size for the shortcut
+ conv. Defaults to ``1``.
+ """
+
+ def __init__(self,
+ in_channels,
+ embedding_channels,
+ use_scale_shift_norm,
+ dropout,
+ out_channels=None,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='SiLU', inplace=False),
+ shortcut_kernel_size=1,
+ up=False,
+ down=False):
+ super().__init__()
+ out_channels = in_channels if out_channels is None else out_channels
+
+ _norm_cfg = deepcopy(norm_cfg)
+
+ _, norm_1 = build_norm_layer(_norm_cfg, in_channels)
+ conv_1 = [
+ norm_1,
+ MODELS.build(act_cfg),
+ nn.Conv2d(in_channels, out_channels, 3, padding=1)
+ ]
+ self.conv_1 = nn.Sequential(*conv_1)
+
+ norm_with_embedding_cfg = dict(
+ in_channels=out_channels,
+ embedding_channels=embedding_channels,
+ use_scale_shift=use_scale_shift_norm,
+ norm_cfg=_norm_cfg)
+ self.norm_with_embedding = MODULES.build(
+ dict(type='NormWithEmbedding'),
+ default_args=norm_with_embedding_cfg)
+
+ conv_2 = [
+ MODELS.build(act_cfg),
+ nn.Dropout(dropout),
+ nn.Conv2d(out_channels, out_channels, 3, padding=1)
+ ]
+ self.conv_2 = nn.Sequential(*conv_2)
+
+ assert shortcut_kernel_size in [
+ 1, 3
+ ], ('Only support `1` and `3` for `shortcut_kernel_size`, but '
+ f'receive {shortcut_kernel_size}.')
+
+ self.learnable_shortcut = out_channels != in_channels
+
+ if self.learnable_shortcut:
+ shortcut_padding = 1 if shortcut_kernel_size == 3 else 0
+ self.shortcut = nn.Conv2d(
+ in_channels,
+ out_channels,
+ shortcut_kernel_size,
+ padding=shortcut_padding)
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = DenoisingUpsample(in_channels, False)
+ self.x_upd = DenoisingUpsample(in_channels, False)
+ elif down:
+ self.h_upd = DenoisingDownsample(in_channels, False)
+ self.x_upd = DenoisingDownsample(in_channels, False)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.init_weights()
+
+ def forward_shortcut(self, x):
+ if self.learnable_shortcut:
+ return self.shortcut(x)
+ return x
+
+ def forward(self, x, y):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): Input feature map tensor.
+ y (torch.Tensor): Shared time embedding or shared label embedding.
+
+ Returns:
+ torch.Tensor : Output feature map tensor.
+ """
+ if self.updown:
+ in_rest, in_conv = self.conv_1[:-1], self.conv_1[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.conv_1(x)
+
+ shortcut = self.forward_shortcut(x)
+ h = self.norm_with_embedding(h, y)
+ h = self.conv_2(h)
+ return h + shortcut
+
+ def init_weights(self):
+ # apply zero init to last conv layer
+ constant_init(self.conv_2[-1], 0)
+
+
+@MODULES.register_module()
+class NormWithEmbedding(BaseModule):
+ """Nornalization with embedding layer. If `use_scale_shift == True`,
+ embedding results will be chunked and used to re-shift and re-scale
+ normalization results. Otherwise, embedding results will directly add to
+ input of normalization layer.
+
+ Args:
+ in_channels (int): Number of channels of the input feature map.
+ embedding_channels (int) Number of channels of the input embedding.
+ norm_cfg (dict, optional): Config for the normalization operation.
+ Defaults to `dict(type='GN', num_groups=32)`.
+ act_cfg (dict, optional): Config for the activation layer. Defaults
+ to `dict(type='SiLU', inplace=False)`.
+ use_scale_shift (bool): If True, the output of Embedding layer will be
+ split to 'scale' and 'shift' and map the output of normalization
+ layer to ``out * (1 + scale) + shift``. Otherwise, the output of
+ Embedding layer will be added with the input before normalization
+ operation. Defaults to True.
+ """
+
+ def __init__(self,
+ in_channels,
+ embedding_channels,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='SiLU', inplace=False),
+ use_scale_shift=True):
+ super().__init__()
+ self.use_scale_shift = use_scale_shift
+ _, self.norm = build_norm_layer(norm_cfg, in_channels)
+
+ embedding_output = in_channels * 2 if use_scale_shift else in_channels
+ self.embedding_layer = nn.Sequential(
+ MODELS.build(act_cfg),
+ nn.Linear(embedding_channels, embedding_output))
+
+ def forward(self, x, y):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): Input feature map tensor.
+ y (torch.Tensor): Shared time embedding or shared label embedding.
+
+ Returns:
+ torch.Tensor : Output feature map tensor.
+ """
+ embedding = self.embedding_layer(y).type(x.dtype)
+ embedding = embedding[:, :, None, None]
+ if self.use_scale_shift:
+ scale, shift = torch.chunk(embedding, 2, dim=1)
+ x = self.norm(x)
+ x = x * (1 + scale) + shift
+ else:
+ x = self.norm(x + embedding)
+ return x
+
+
+@MODULES.register_module()
+class DenoisingDownsample(BaseModule):
+ """Downsampling operation used in the denoising network. Support average
+ pooling and convolution for downsample operation.
+
+ Args:
+ in_channels (int): Number of channels of the input feature map to be
+ downsampled.
+ with_conv (bool, optional): Whether use convolution operation for
+ downsampling. Defaults to `True`.
+ """
+
+ def __init__(self, in_channels, with_conv=True):
+ super().__init__()
+ if with_conv:
+ self.downsample = nn.Conv2d(in_channels, in_channels, 3, 2, 1)
+ else:
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ """Forward function for downsampling operation.
+ Args:
+ x (torch.Tensor): Feature map to downsample.
+
+ Returns:
+ torch.Tensor: Feature map after downsampling.
+ """
+ return self.downsample(x)
+
+
+@MODULES.register_module()
+class DenoisingUpsample(BaseModule):
+ """Upsampling operation used in the denoising network. Allows users to
+ apply an additional convolution layer after the nearest interpolation
+ operation.
+
+ Args:
+ in_channels (int): Number of channels of the input feature map to be
+ downsampled.
+ with_conv (bool, optional): Whether apply an additional convolution
+ layer after upsampling. Defaults to `True`.
+ """
+
+ def __init__(self, in_channels, with_conv=True):
+ super().__init__()
+ self.with_conv = with_conv
+ if with_conv:
+ self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
+
+ def forward(self, x):
+ """Forward function for upsampling operation.
+ Args:
+ x (torch.Tensor): Feature map to upsample.
+
+ Returns:
+ torch.Tensor: Feature map after upsampling.
+ """
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+@MODULES.register_module()
+class DenoisingUnet(BaseModule):
+ """Denoising Unet. This network receives a diffused image ``x_t`` and
+ current timestep ``t``, and returns a ``output_dict`` corresponding to the
+ passed ``output_cfg``.
+
+ ``output_cfg`` defines the number of channels and the meaning of the
+ output. ``output_cfg`` mainly contains keys of ``mean`` and ``var``,
+ denoting how the network outputs mean and variance required for the
+ denoising process.
+ For ``mean``:
+ 1. ``dict(mean='EPS')``: Model will predict noise added in the
+ diffusion process, and the ``output_dict`` will contain a key named
+ ``eps_t_pred``.
+ 2. ``dict(mean='START_X')``: Model will direct predict the mean of the
+ original image `x_0`, and the ``output_dict`` will contain a key named
+ ``x_0_pred``.
+ 3. ``dict(mean='X_TM1_PRED')``: Model will predict the mean of diffused
+ image at `t-1` timestep, and the ``output_dict`` will contain a key
+ named ``x_tm1_pred``.
+
+ For ``var``:
+ 1. ``dict(var='FIXED_SMALL')`` or ``dict(var='FIXED_LARGE')``: Variance in
+ the denoising process is regarded as a fixed value. Therefore only
+ 'mean' will be predicted, and the output channels will equal to the
+ input image (e.g., three channels for RGB image.)
+ 2. ``dict(var='LEARNED')``: Model will predict `log_variance` in the
+ denoising process, and the ``output_dict`` will contain a key named
+ ``log_var``.
+ 3. ``dict(var='LEARNED_RANGE')``: Model will predict an interpolation
+ factor and the `log_variance` will be calculated as
+ `factor * upper_bound + (1-factor) * lower_bound`. The ``output_dict``
+ will contain a key named ``factor``.
+
+ If ``var`` is not ``FIXED_SMALL`` or ``FIXED_LARGE``, the number of output
+ channels will be the double of input channels, where the first half part
+ contains predicted mean values and the other part is the predicted
+ variance values. Otherwise, the number of output channels equals to the
+ input channels, only containing the predicted mean values.
+
+ Args:
+ image_size (int | list[int]): The size of image to denoise.
+ in_channels (int, optional): The input channels of the input image.
+ Defaults as ``3``.
+ base_channels (int, optional): The basic channel number of the
+ generator. The other layers contain channels based on this number.
+ Defaults to ``128``.
+ resblocks_per_downsample (int, optional): Number of ResBlock used
+ between two downsample operations. The number of ResBlock between
+ upsample operations will be the same value to keep symmetry.
+ Defaults to 3.
+ num_timesteps (int, optional): The total timestep of the denoising
+ process and the diffusion process. Defaults to ``1000``.
+ use_rescale_timesteps (bool, optional): Whether rescale the input
+ timesteps in range of [0, 1000]. Defaults to ``True``.
+ dropout (float, optional): The probability of dropout operation of
+ each ResBlock. Pass ``0`` to do not use dropout. Defaults as 0.
+ embedding_channels (int, optional): The output channels of time
+ embedding layer and label embedding layer. If not passed (or
+ passed ``-1``), output channels of the embedding layers will set
+ as four times of ``base_channels``. Defaults to ``-1``.
+ num_classes (int, optional): The number of conditional classes. If set
+ to 0, this model will be degraded to an unconditional model.
+ Defaults to 0.
+ channels_cfg (list | dict[list], optional): Config for input channels
+ of the intermedia blocks. If list is passed, each element of the
+ list indicates the scale factor for the input channels of the
+ current block with regard to the ``base_channels``. For block
+ ``i``, the input and output channels should be
+ ``channels_cfg[i] * base_channels`` and
+ ``channels_cfg[i+1] * base_channels`` If dict is provided, the key
+ of the dict should be the output scale and corresponding value
+ should be a list to define channels. Default: Please refer to
+ ``_defualt_channels_cfg``.
+ output_cfg (dict, optional): Config for output variables. Defaults to
+ ``dict(mean='eps', var='learned_range')``.
+ norm_cfg (dict, optional): The config for normalization layers.
+ Defaults to ``dict(type='GN', num_groups=32)``.
+ act_cfg (dict, optional): The config for activation layers. Defaults
+ to ``dict(type='SiLU', inplace=False)``.
+ shortcut_kernel_size (int, optional): The kernel size for shortcut
+ conv in ResBlocks. The value of this argument will overwrite the
+ default value of `resblock_cfg`. Defaults to `3`.
+ use_scale_shift_norm (bool, optional): Whether perform scale and shift
+ after normalization operation. Defaults to True.
+ num_heads (int, optional): The number of attention heads. Defaults to
+ 4.
+ time_embedding_mode (str, optional): Embedding method of
+ ``time_embedding``. Defaults to 'sin'.
+ time_embedding_cfg (dict, optional): Config for ``time_embedding``.
+ Defaults to None.
+ resblock_cfg (dict, optional): Config for ResBlock. Defaults to
+ ``dict(type='DenoisingResBlock')``.
+ attention_cfg (dict, optional): Config for attention operation.
+ Defaults to ``dict(type='MultiHeadAttention')``.
+ upsample_conv (bool, optional): Whether use conv in upsample block.
+ Defaults to ``True``.
+ downsample_conv (bool, optional): Whether use conv operation in
+ downsample block. Defaults to ``True``.
+ upsample_cfg (dict, optional): Config for upsample blocks.
+ Defaults to ``dict(type='DenoisingDownsample')``.
+ downsample_cfg (dict, optional): Config for downsample blocks.
+ Defaults to ``dict(type='DenoisingUpsample')``.
+ attention_res (int | list[int], optional): Resolution of feature maps
+ to apply attention operation. Defaults to ``[16, 8]``.
+ pretrained (str | dict, optional): Path for the pretrained model or
+ dict containing information for pretained models whose necessary
+ key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
+ the generator part from the whole state dict. Defaults to None.
+ """
+
+ _default_channels_cfg = {
+ 512: [0.5, 1, 1, 2, 2, 4, 4],
+ 256: [1, 1, 2, 2, 4, 4],
+ 128: [1, 1, 2, 3, 4],
+ 64: [1, 2, 3, 4],
+ 32: [1, 2, 2, 2]
+ }
+
+ def __init__(self,
+ image_size,
+ in_channels=3,
+ base_channels=128,
+ resblocks_per_downsample=3,
+ num_timesteps=1000,
+ use_rescale_timesteps=False,
+ dropout=0,
+ embedding_channels=-1,
+ num_classes=0,
+ use_fp16=False,
+ channels_cfg=None,
+ output_cfg=dict(mean='eps', var='learned_range'),
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='SiLU', inplace=False),
+ shortcut_kernel_size=1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ num_heads=4,
+ time_embedding_mode='sin',
+ time_embedding_cfg=None,
+ resblock_cfg=dict(type='DenoisingResBlock'),
+ attention_cfg=dict(type='MultiHeadAttention'),
+ downsample_conv=True,
+ upsample_conv=True,
+ downsample_cfg=dict(type='DenoisingDownsample'),
+ upsample_cfg=dict(type='DenoisingUpsample'),
+ attention_res=[16, 8],
+ pretrained=None):
+
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_timesteps = num_timesteps
+ self.use_rescale_timesteps = use_rescale_timesteps
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+
+ self.output_cfg = deepcopy(output_cfg)
+ self.mean_mode = self.output_cfg.get('mean', 'eps')
+ self.var_mode = self.output_cfg.get('var', 'learned_range')
+ self.in_channels = in_channels
+
+ # double output_channels to output mean and var at same time
+ out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
+ else 2 * in_channels
+ self.out_channels = out_channels
+
+ # check type of image_size
+ if not isinstance(image_size, int) and not isinstance(
+ image_size, list):
+ raise TypeError(
+ 'Only support `int` and `list[int]` for `image_size`.')
+ if isinstance(image_size, list):
+ assert len(
+ image_size) == 2, 'The length of `image_size` should be 2.'
+ assert image_size[0] == image_size[
+ 1], 'Width and height of the image should be same.'
+ image_size = image_size[0]
+ self.image_size = image_size
+
+ channels_cfg = deepcopy(self._default_channels_cfg) \
+ if channels_cfg is None else deepcopy(channels_cfg)
+ if isinstance(channels_cfg, dict):
+ if image_size not in channels_cfg:
+ raise KeyError(f'`image_size={image_size} is not found in '
+ '`channels_cfg`, only support configs for '
+ f'{[chn for chn in channels_cfg.keys()]}')
+ self.channel_factor_list = channels_cfg[image_size]
+ elif isinstance(channels_cfg, list):
+ self.channel_factor_list = channels_cfg
+ else:
+ raise ValueError('Only support list or dict for `channels_cfg`, '
+ f'receive {type(channels_cfg)}')
+
+ embedding_channels = base_channels * 4 \
+ if embedding_channels == -1 else embedding_channels
+ self.time_embedding = TimeEmbedding(
+ base_channels,
+ embedding_channels=embedding_channels,
+ embedding_mode=time_embedding_mode,
+ embedding_cfg=time_embedding_cfg,
+ act_cfg=act_cfg)
+
+ if self.num_classes != 0:
+ self.label_embedding = nn.Embedding(self.num_classes,
+ embedding_channels)
+
+ self.resblock_cfg = deepcopy(resblock_cfg)
+ self.resblock_cfg.setdefault('dropout', dropout)
+ self.resblock_cfg.setdefault('norm_cfg', norm_cfg)
+ self.resblock_cfg.setdefault('act_cfg', act_cfg)
+ self.resblock_cfg.setdefault('embedding_channels', embedding_channels)
+ self.resblock_cfg.setdefault('use_scale_shift_norm',
+ use_scale_shift_norm)
+ self.resblock_cfg.setdefault('shortcut_kernel_size',
+ shortcut_kernel_size)
+
+ # get scales of ResBlock to apply attention
+ attention_scale = [image_size // int(res) for res in attention_res]
+ self.attention_cfg = deepcopy(attention_cfg)
+ self.attention_cfg.setdefault('num_heads', num_heads)
+ self.attention_cfg.setdefault('norm_cfg', norm_cfg)
+
+ self.downsample_cfg = deepcopy(downsample_cfg)
+ self.downsample_cfg.setdefault('with_conv', downsample_conv)
+ self.upsample_cfg = deepcopy(upsample_cfg)
+ self.upsample_cfg.setdefault('with_conv', upsample_conv)
+
+ # init the channel scale factor
+ scale = 1
+ ch = int(base_channels * self.channel_factor_list[0])
+ self.in_blocks = nn.ModuleList(
+ [EmbedSequential(nn.Conv2d(in_channels, ch, 3, 1, padding=1))])
+ self.in_channels_list = [ch]
+
+ # construct the encoder part of Unet
+ for level, factor in enumerate(self.channel_factor_list):
+ in_channels_ = ch if level == 0 \
+ else int(base_channels * self.channel_factor_list[level - 1])
+ out_channels_ = int(base_channels * factor)
+ for _ in range(resblocks_per_downsample):
+ layers = [
+ MODULES.build(
+ self.resblock_cfg,
+ default_args={
+ 'in_channels': in_channels_,
+ 'out_channels': out_channels_
+ })
+ ]
+ in_channels_ = out_channels_
+
+ if scale in attention_scale:
+ layers.append(
+ MODULES.build(
+ self.attention_cfg,
+ default_args={'in_channels': in_channels_}))
+
+ self.in_channels_list.append(in_channels_)
+ self.in_blocks.append(EmbedSequential(*layers))
+
+ if level != len(self.channel_factor_list) - 1:
+ self.in_blocks.append(
+ EmbedSequential(
+ DenoisingResBlock(
+ out_channels_,
+ embedding_channels,
+ use_scale_shift_norm,
+ dropout,
+ norm_cfg=norm_cfg,
+ out_channels=out_channels_,
+ down=True) if resblock_updown else MODULES.build(
+ self.downsample_cfg,
+ default_args={'in_channels': in_channels_})))
+ self.in_channels_list.append(in_channels_)
+ scale *= 2
+
+ # construct the bottom part of Unet
+ self.mid_blocks = EmbedSequential(
+ MODULES.build(
+ self.resblock_cfg, default_args={'in_channels': in_channels_}),
+ MODULES.build(
+ self.attention_cfg, default_args={'in_channels':
+ in_channels_}),
+ MODULES.build(
+ self.resblock_cfg, default_args={'in_channels': in_channels_}),
+ )
+
+ # construct the decoder part of Unet
+ in_channels_list = deepcopy(self.in_channels_list)
+ self.out_blocks = nn.ModuleList()
+ for level, factor in enumerate(self.channel_factor_list[::-1]):
+ for idx in range(resblocks_per_downsample + 1):
+ layers = [
+ MODULES.build(
+ self.resblock_cfg,
+ default_args={
+ 'in_channels':
+ in_channels_ + in_channels_list.pop(),
+ 'out_channels': int(base_channels * factor)
+ })
+ ]
+ in_channels_ = int(base_channels * factor)
+ if scale in attention_scale:
+ layers.append(
+ MODULES.build(
+ self.attention_cfg,
+ default_args={'in_channels': in_channels_}))
+ if (level != len(self.channel_factor_list) - 1
+ and idx == resblocks_per_downsample):
+ out_channels_ = in_channels_
+ layers.append(
+ DenoisingResBlock(
+ in_channels_,
+ embedding_channels,
+ use_scale_shift_norm,
+ dropout,
+ norm_cfg=norm_cfg,
+ out_channels=out_channels_,
+ up=True) if resblock_updown else MODULES.build(
+ self.upsample_cfg,
+ default_args={'in_channels': in_channels_}))
+ scale //= 2
+ self.out_blocks.append(EmbedSequential(*layers))
+
+ self.out = ConvModule(
+ in_channels=in_channels_,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ bias=True,
+ order=('norm', 'act', 'conv'))
+
+ self.init_weights(pretrained)
+
+ def forward(self, x_t, t, label=None, return_noise=False):
+ """Forward function.
+ Args:
+ x_t (torch.Tensor): Diffused image at timestep `t` to denoise.
+ t (torch.Tensor): Current timestep.
+ label (torch.Tensor | callable | None): You can directly give a
+ batch of label through a ``torch.Tensor`` or offer a callable
+ function to sample a batch of label data. Otherwise, the
+ ``None`` indicates to use the default label sampler.
+ return_noise (bool, optional): If True, inputted ``x_t`` and ``t``
+ will be returned in a dict with output desired by
+ ``output_cfg``. Defaults to False.
+
+ Returns:
+ torch.Tensor | dict: If not ``return_noise``
+ """
+ if not torch.is_tensor(t):
+ t = torch.tensor([t], dtype=torch.long, device=x_t.device)
+ elif torch.is_tensor(t) and len(t.shape) == 0:
+ t = t[None].to(x_t.device)
+
+ embedding = self.time_embedding(t)
+
+ if label is not None:
+ assert hasattr(self, 'label_embedding')
+ embedding = self.label_embedding(label) + embedding
+
+ h, hs = x_t, []
+ h = h.type(self.dtype)
+ # forward downsample blocks
+ for block in self.in_blocks:
+ h = block(h, embedding)
+ hs.append(h)
+
+ # forward middle blocks
+ h = self.mid_blocks(h, embedding)
+
+ # forward upsample blocks
+ for block in self.out_blocks:
+ h = block(torch.cat([h, hs.pop()], dim=1), embedding)
+ h = h.type(x_t.dtype)
+ outputs = self.out(h)
+
+ return {'outputs': outputs}
+
+ def init_weights(self, pretrained=None):
+ """Init weights for models.
+
+ We just use the initialization method proposed in the original paper.
+
+ Args:
+ pretrained (str, optional): Path for pretrained weights. If given
+ None, pretrained weights will not be loaded. Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = MMLogger.get_current_instance()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ # As Improved-DDPM, we apply zero-initialization to
+ # second conv block in ResBlock (keywords: conv_2)
+ # the output layer of the Unet (keywords: 'out' but
+ # not 'out_blocks')
+ # projection layer in Attention layer (keywords: proj)
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d) and ('conv_2' in n or
+ ('out' in n
+ and 'out_blocks' not in n)):
+ constant_init(m, 0)
+ if isinstance(m, nn.Conv1d) and 'proj' in n:
+ constant_init(m, 0)
+ else:
+ raise TypeError('pretrained must be a str or None but'
+ f' got {type(pretrained)} instead.')
+
+ def convert_to_fp16(self):
+ """Convert the precision of the model to float16."""
+ self.in_blocks.apply(convert_module_to_f16)
+ self.mid_blocks.apply(convert_module_to_f16)
+ self.out_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """Convert the precision of the model to float32."""
+ self.in_blocks.apply(convert_module_to_f32)
+ self.mid_blocks.apply(convert_module_to_f32)
+ self.out_blocks.apply(convert_module_to_f32)
diff --git a/mmedit/models/editors/disco_diffusion/__init__.py b/mmedit/models/editors/disco_diffusion/__init__.py
new file mode 100644
index 0000000000..ae1b8ebba7
--- /dev/null
+++ b/mmedit/models/editors/disco_diffusion/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .clip_wrapper import ClipWrapper
+from .disco import DiscoDiffusion
+from .guider import ImageTextGuider
+from .secondary_model import SecondaryDiffusionImageNet2, alpha_sigma_to_t
+
+__all__ = [
+ 'DiscoDiffusion', 'ImageTextGuider', 'ClipWrapper',
+ 'SecondaryDiffusionImageNet2', 'alpha_sigma_to_t'
+]
diff --git a/mmedit/models/editors/disco_diffusion/clip_wrapper.py b/mmedit/models/editors/disco_diffusion/clip_wrapper.py
new file mode 100644
index 0000000000..463dad657d
--- /dev/null
+++ b/mmedit/models/editors/disco_diffusion/clip_wrapper.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from mmedit.registry import MODELS
+
+
+@MODELS.register_module()
+class ClipWrapper(nn.Module):
+ """Clip Models wrapper for disco-diffusion.
+
+ We provide wrappers for the clip models of ``openai`` and
+ ``mlfoundations``, where the user can specify ``clip_type``
+ as ``clip`` or ``open_clip``, and then initialize a clip model
+ using the same arguments as in the original codebase. The
+ following clip models settings are provided in the official
+ repo of disco diffusion:
+
+ | Setting | Source | Arguments | # noqa
+ |:-----------------------------:|-----------|--------------------------------------------------------------| # noqa
+ | ViTB32 | clip | name='ViT-B/32', jit=False | # noqa
+ | ViTB16 | clip | name='ViT-B/16', jit=False | # noqa
+ | ViTL14 | clip | name='ViT-L/14', jit=False | # noqa
+ | ViTL14_336px | clip | name='ViT-L/14@336px', jit=False | # noqa
+ | RN50 | clip | name='RN50', jit=False | # noqa
+ | RN50x4 | clip | name='RN50x4', jit=False | # noqa
+ | RN50x16 | clip | name='RN50x16', jit=False | # noqa
+ | RN50x64 | clip | name='RN50x64', jit=False | # noqa
+ | RN101 | clip | name='RN101', jit=False | # noqa
+ | ViTB32_laion2b_e16 | open_clip | name='ViT-B-32', pretrained='laion2b_e16' | # noqa
+ | ViTB32_laion400m_e31 | open_clip | model_name='ViT-B-32', pretrained='laion400m_e31' | # noqa
+ | ViTB32_laion400m_32 | open_clip | model_name='ViT-B-32', pretrained='laion400m_e32' | # noqa
+ | ViTB32quickgelu_laion400m_e31 | open_clip | model_name='ViT-B-32-quickgelu', pretrained='laion400m_e31' | # noqa
+ | ViTB32quickgelu_laion400m_e32 | open_clip | model_name='ViT-B-32-quickgelu', pretrained='laion400m_e32' | # noqa
+ | ViTB16_laion400m_e31 | open_clip | model_name='ViT-B-16', pretrained='laion400m_e31' | # noqa
+ | ViTB16_laion400m_e32 | open_clip | model_name='ViT-B-16', pretrained='laion400m_e32' | # noqa
+ | RN50_yffcc15m | open_clip | model_name='RN50', pretrained='yfcc15m' | # noqa
+ | RN50_cc12m | open_clip | model_name='RN50', pretrained='cc12m' | # noqa
+ | RN50_quickgelu_yfcc15m | open_clip | model_name='RN50-quickgelu', pretrained='yfcc15m' | # noqa
+ | RN50_quickgelu_cc12m | open_clip | model_name='RN50-quickgelu', pretrained='cc12m' | # noqa
+ | RN101_yfcc15m | open_clip | model_name='RN101', pretrained='yfcc15m' | # noqa
+ | RN101_quickgelu_yfcc15m | open_clip | model_name='RN101-quickgelu', pretrained='yfcc15m' | # noqa
+
+ An example of a ``clip_modes_cfg`` is as follows:
+ .. code-block:: python
+
+ clip_models = [
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/32', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='ViT-B/16', jit=False),
+ dict(type='ClipWrapper', clip_type='clip', name='RN50', jit=False)
+ ]
+
+ Args:
+ clip_type (List[Dict]): The original source of the clip model. Whether be
+ ``clip`` or ``open_clip``.
+ """
+
+ def __init__(self, clip_type, *args, **kwargs):
+
+ super().__init__()
+ self.clip_type = clip_type
+ assert clip_type in ['clip', 'open_clip']
+ if clip_type == 'clip':
+ try:
+ import clip
+ except ImportError:
+ raise ImportError(
+ 'clip need to be installed! Run `pip install -r requirements/optional.txt` and try again' # noqa
+ ) # noqa
+ self.model, _ = clip.load(*args, **kwargs)
+ elif clip_type == 'open_clip':
+ try:
+ import open_clip
+ except ImportError:
+ raise ImportError(
+ 'open_clip_torch need to be installed! Run `pip install -r requirements/optional.txt` and try again' # noqa
+ ) # noqa
+ self.model = open_clip.create_model(*args, **kwargs)
+ self.model.eval().requires_grad_(False)
+
+ def forward(self, *args, **kwargs):
+ """Forward function."""
+ return self.model(*args, **kwargs)
diff --git a/mmedit/models/editors/disco_diffusion/disco.py b/mmedit/models/editors/disco_diffusion/disco.py
new file mode 100644
index 0000000000..3a4e01e700
--- /dev/null
+++ b/mmedit/models/editors/disco_diffusion/disco.py
@@ -0,0 +1,236 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Union
+
+import mmcv
+import mmengine
+import torch
+import torch.nn as nn
+from mmengine.runner import set_random_seed
+from mmengine.runner.checkpoint import (_load_checkpoint,
+ _load_checkpoint_with_prefix)
+from tqdm import tqdm
+
+from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
+from .guider import ImageTextGuider
+
+ModelType = Union[Dict, nn.Module]
+
+
+@MODELS.register_module('disco')
+@MODELS.register_module('dd')
+@MODELS.register_module()
+class DiscoDiffusion(nn.Module):
+ """Disco Diffusion (DD) is a Google Colab Notebook which leverages an AI
+ Image generating technique called CLIP-Guided Diffusion to allow you to
+ create compelling and beautiful images from just text inputs. Created by
+ Somnai, augmented by Gandamu, and building on the work of RiversHaveWings,
+ nshepperd, and many others.
+
+ Ref:
+ Github Repo: https://github.com/alembics/disco-diffusion
+ Colab: https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb # noqa
+
+ Args:
+ unet (ModelType): Config of denoising Unet.
+ diffusion_scheduler (ModelType): Config of diffusion_scheduler scheduler.
+ secondary_model (ModelType): A smaller secondary diffusion model
+ trained by Katherine Crowson to remove noise from intermediate
+ timesteps to prepare them for CLIP.
+ Ref: https://twitter.com/rivershavewings/status/1462859669454536711 # noqa
+ Defaults to None.
+ clip_models (list): Config of clip models. Defaults to [].
+ use_fp16 (bool): Whether to use fp16 for unet model. Defaults to False.
+ pretrained_cfgs (dict): Path Config for pretrained weights. Usually
+ this is a dict contains module name and the corresponding ckpt
+ path. Defaults to None.
+ """
+
+ def __init__(self,
+ unet,
+ diffusion_scheduler,
+ secondary_model=None,
+ clip_models=[],
+ use_fp16=False,
+ pretrained_cfgs=None):
+ super().__init__()
+ self.unet = unet if isinstance(unet,
+ nn.Module) else MODULES.build(unet)
+ self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build(
+ diffusion_scheduler) if isinstance(diffusion_scheduler,
+ dict) else diffusion_scheduler
+
+ assert len(clip_models) > 0
+ if isinstance(clip_models[0], nn.Module):
+ _clip_models = clip_models
+ else:
+ _clip_models = []
+ for clip_cfg in clip_models:
+ _clip_models.append(MODULES.build(clip_cfg))
+ self.guider = ImageTextGuider(_clip_models)
+
+ if secondary_model is not None:
+ self.secondary_model = secondary_model if isinstance(
+ secondary_model, nn.Module) else MODULES.build(secondary_model)
+ self.with_secondary_model = True
+ else:
+ self.with_secondary_model = False
+
+ if pretrained_cfgs:
+ self.load_pretrained_models(pretrained_cfgs)
+ if use_fp16:
+ mmengine.print_log('Convert unet modules to floatpoint16')
+ self.unet.convert_to_fp16()
+
+ def load_pretrained_models(self, pretrained_cfgs):
+ """Loading pretrained weights to model. ``pretrained_cfgs`` is a dict
+ consist of module name as key and checkpoint path as value.
+
+ Args:
+ pretrained_cfgs (dict): Path Config for pretrained weights.
+ Usually this is a dict contains module name and the
+ corresponding ckpt path. Defaults to None.
+ """
+ for key, ckpt_cfg in pretrained_cfgs.items():
+ prefix = ckpt_cfg.get('prefix', '')
+ map_location = ckpt_cfg.get('map_location', 'cpu')
+ strict = ckpt_cfg.get('strict', True)
+ ckpt_path = ckpt_cfg.get('ckpt_path')
+ if prefix:
+ state_dict = _load_checkpoint_with_prefix(
+ prefix, ckpt_path, map_location)
+ else:
+ state_dict = _load_checkpoint(ckpt_path, map_location)
+ getattr(self, key).load_state_dict(state_dict, strict=strict)
+ mmengine.print_log(f'Load pretrained {key} from {ckpt_path}')
+
+ @property
+ def device(self):
+ """Get current device of the model.
+
+ Returns:
+ torch.device: The current device of the model.
+ """
+ return next(self.parameters()).device
+
+ @torch.no_grad()
+ def infer(self,
+ scheduler_kwargs=None,
+ height=None,
+ width=None,
+ init_image=None,
+ batch_size=1,
+ num_inference_steps=1000,
+ skip_steps=0,
+ show_progress=False,
+ text_prompts=[],
+ image_prompts=[],
+ eta=0.8,
+ clip_guidance_scale=5000,
+ init_scale=1000,
+ tv_scale=0.,
+ sat_scale=0.,
+ range_scale=150,
+ cut_overview=[12] * 400 + [4] * 600,
+ cut_innercut=[4] * 400 + [12] * 600,
+ cut_ic_pow=[1] * 1000,
+ cut_icgray_p=[0.2] * 400 + [0] * 600,
+ cutn_batches=4,
+ seed=None):
+ """Inference API for disco diffusion.
+
+ Args:
+ scheduler_kwargs (dict): Args for infer time diffusion
+ scheduler. Defaults to None.
+ height (int): Height of output image. Defaults to None.
+ width (int): Width of output image. Defaults to None.
+ init_image (str): Initial image at the start point
+ of denoising. Defaults to None.
+ batch_size (int): Batch size. Defaults to 1.
+ num_inference_steps (int): Number of inference steps.
+ Defaults to 1000.
+ skip_steps (int): Denoising steps to skip, usually set
+ with ``init_image``. Defaults to 0.
+ show_progress (bool): Whether to show progress.
+ Defaults to False.
+ text_prompts (list): Text prompts. Defaults to [].
+ image_prompts (list): Image prompts, this is not the same as
+ ``init_image``, they works the same way with
+ ``text_prompts``. Defaults to [].
+ eta (float): Eta for ddim sampling. Defaults to 0.8.
+ clip_guidance_scale (int): The Scale of influence of prompts
+ on output image. Defaults to 1000.
+ seed (int): Sampling seed. Defaults to None.
+ """
+ # set diffusion_scheduler
+ if scheduler_kwargs is not None:
+ mmengine.print_log('Switch to infer diffusion scheduler!',
+ 'current')
+ infer_scheduler = DIFFUSION_SCHEDULERS.build(scheduler_kwargs)
+ else:
+ infer_scheduler = self.diffusion_scheduler
+ # set random seed
+ if isinstance(seed, int):
+ set_random_seed(seed=seed)
+
+ # set step values
+ if num_inference_steps > 0:
+ infer_scheduler.set_timesteps(num_inference_steps)
+
+ _ = image_prompts
+
+ height = (height // 64) * 64 if height else self.unet.image_size
+ width = (width // 64) * 64 if width else self.unet.image_size
+ if init_image is None:
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, height, width))
+ image = image.to(self.device)
+ else:
+ init = mmcv.imread(init_image, channel_order='rgb')
+ init = mmcv.imresize(
+ init, (width, height), interpolation='lanczos') / 255.
+ init_image = torch.as_tensor(
+ init,
+ dtype=torch.float32).to(self.device).unsqueeze(0).permute(
+ 0, 3, 1, 2).mul(2).sub(1)
+ image = init_image.clone()
+ image = infer_scheduler.add_noise(
+ image, torch.randn_like(image),
+ infer_scheduler.timesteps[skip_steps])
+ # get stats from text prompts and image prompts
+ model_stats = self.guider.compute_prompt_stats(
+ text_prompts=text_prompts)
+ timesteps = infer_scheduler.timesteps[skip_steps:]
+ if show_progress:
+ timesteps = tqdm(timesteps)
+ for t in timesteps:
+ # 1. predicted model_output
+ model_output = self.unet(image, t)['outputs']
+
+ # 2. compute previous image: x_t -> x_t-1
+ cond_kwargs = dict(
+ model_stats=model_stats,
+ init_image=init_image,
+ unet=self.unet,
+ clip_guidance_scale=clip_guidance_scale,
+ init_scale=init_scale,
+ tv_scale=tv_scale,
+ sat_scale=sat_scale,
+ range_scale=range_scale,
+ cut_overview=cut_overview,
+ cut_innercut=cut_innercut,
+ cut_ic_pow=cut_ic_pow,
+ cut_icgray_p=cut_icgray_p,
+ cutn_batches=cutn_batches,
+ )
+ if self.with_secondary_model:
+ cond_kwargs.update(secondary_model=self.secondary_model)
+ diffusion_scheduler_output = infer_scheduler.step(
+ model_output,
+ t,
+ image,
+ cond_fn=self.guider.cond_fn,
+ cond_kwargs=cond_kwargs,
+ eta=eta)
+
+ image = diffusion_scheduler_output['prev_sample']
+ return {'samples': image}
diff --git a/mmedit/models/editors/disco_diffusion/guider.py b/mmedit/models/editors/disco_diffusion/guider.py
new file mode 100644
index 0000000000..b1c425dd3b
--- /dev/null
+++ b/mmedit/models/editors/disco_diffusion/guider.py
@@ -0,0 +1,500 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import clip
+import lpips
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from resize_right import resize
+
+from mmedit.models.losses import tv_loss
+from .secondary_model import alpha_sigma_to_t
+
+normalize = T.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711])
+
+
+def sinc(x):
+ """
+ Sinc function.
+ If x equal to 0,
+ sinc(x) = 1
+ else:
+ sinc(x) = sin(x)/ x
+ Args:
+ x (torch.Tensor): Input Tensor
+
+ Returns:
+ torch.Tensor: Function output.
+ """
+ return torch.where(x != 0,
+ torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+ """Lanczos filter's reconstruction kernel L(x)."""
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+ return out / out.sum()
+
+
+def ramp(ratio, width):
+ """_summary_
+
+ Args:
+ ratio (_type_): _description_
+ width (_type_): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ n = math.ceil(width / ratio + 1)
+ out = torch.empty([n])
+ cur = 0
+ for i in range(out.shape[0]):
+ out[i] = cur
+ cur += ratio
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+
+def resample(input, size, align_corners=True):
+ """Lanczos resampling image.
+
+ Args:
+ input (torch.Tensor): Input image tensor.
+ size (Tuple[int, int]): Output image size.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Defaults to True.
+
+ Returns:
+ torch.Tensor: Resampling results.
+ """
+ n, c, h, w = input.shape
+ dh, dw = size
+
+ input = input.reshape([n * c, 1, h, w])
+
+ if dh < h:
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+ pad_h = (kernel_h.shape[0] - 1) // 2
+ input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
+ input = F.conv2d(input, kernel_h[None, None, :, None])
+
+ if dw < w:
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+ pad_w = (kernel_w.shape[0] - 1) // 2
+ input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
+ input = F.conv2d(input, kernel_w[None, None, None, :])
+
+ input = input.reshape([n, c, h, w])
+ return F.interpolate(
+ input, size, mode='bicubic', align_corners=align_corners)
+
+
+def range_loss(input):
+ """range loss."""
+ return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+
+def spherical_dist_loss(x, y):
+ """spherical distance loss."""
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+class MakeCutouts(nn.Module):
+ """Each iteration, the AI cuts the image into smaller pieces known as cuts.
+
+ , and compares each cut to the prompt to decide how to guide the next
+ diffusion step.
+ This classes will randomly cut patches and perform image augmentation to
+ these patches.
+
+ Args:
+ cut_size (int): Size of the patches.
+ cutn (int): Number of patches to cut.
+ """
+
+ def __init__(self, cut_size, cutn):
+ super().__init__()
+ self.cut_size = cut_size
+ self.cutn = cutn
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.15),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ ])
+
+ def forward(self, input, skip_augs=False):
+ input = T.Pad(input.shape[2] // 4, fill=0)(input)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+
+ cutouts = []
+ for ch in range(self.cutn):
+ if ch > self.cutn - self.cutn // 4:
+ cutout = input.clone()
+ else:
+ size = int(max_size * torch.zeros(1, ).normal_(
+ mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.))
+ offsetx = torch.randint(0, abs(sideX - size + 1), ())
+ offsety = torch.randint(0, abs(sideY - size + 1), ())
+ cutout = input[:, :, offsety:offsety + size,
+ offsetx:offsetx + size]
+
+ if not skip_augs:
+ cutout = self.augs(cutout)
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+ del cutout
+
+ cutouts = torch.cat(cutouts, dim=0)
+ return cutouts
+
+
+class MakeCutoutsDango(nn.Module):
+ """Dango233(https://github.com/Dango233)'s version of MakeCutouts.
+
+ The improvement compared to ``MakeCutouts`` is that it use partial
+ greyscale augmentation to capture structure, and partial rotation
+ augmentation to capture whole frames.
+
+ Args:
+ cut_size (int): Size of the patches.
+ Overview (int): The total number of overview cuts.
+ In details,
+ Overview=1, Add whole frame;
+ Overview=2, Add grayscaled frame;
+ Overview=3, Add horizontal flip frame;
+ Overview=4, Add grayscaled horizontal flip frame;
+ Overview>4, Repeat add frame Overview times.
+ Defaults to 4.
+ InnerCrop (int): The total number of inner cuts.
+ Defaults to 0.
+ IC_Size_Pow (float): This sets the size of the border
+ used for inner cuts. High values have larger borders,
+ and therefore the cuts themselves will be smaller and
+ provide finer details. Defaults to 0.5.
+ IC_Grey_P (float): The portion of the inner cuts can be set to be
+ grayscale instead of color. This may help with improved
+ definition of shapes and edges, especially in the early
+ diffusion steps where the image structure is being defined.
+ Defaults to 0.2.
+ """
+
+ def __init__(self,
+ cut_size,
+ Overview=4,
+ InnerCrop=0,
+ IC_Size_Pow=0.5,
+ IC_Grey_P=0.2):
+ super().__init__()
+ self.cut_size = cut_size
+ self.Overview = Overview
+ self.InnerCrop = InnerCrop
+ self.IC_Size_Pow = IC_Size_Pow
+ self.IC_Grey_P = IC_Grey_P
+
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(
+ degrees=10,
+ translate=(0.05, 0.05),
+ interpolation=T.InterpolationMode.BILINEAR),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(
+ brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ])
+
+ def forward(self, input, skip_augs=False):
+ """Forward function."""
+ cutouts = []
+ gray = T.Grayscale(3)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ output_shape = [1, 3, self.cut_size, self.cut_size]
+ pad_input = F.pad(input,
+ ((sideY - max_size) // 2, (sideY - max_size) // 2,
+ (sideX - max_size) // 2, (sideX - max_size) // 2))
+ cutout = resize(pad_input, out_shape=output_shape)
+
+ if self.Overview > 0:
+ if self.Overview <= 4:
+ if self.Overview >= 1:
+ cutouts.append(cutout)
+ if self.Overview >= 2:
+ cutouts.append(gray(cutout))
+ if self.Overview >= 3:
+ cutouts.append(TF.hflip(cutout))
+ if self.Overview == 4:
+ cutouts.append(gray(TF.hflip(cutout)))
+ else:
+ cutout = resize(pad_input, out_shape=output_shape)
+ for _ in range(self.Overview):
+ cutouts.append(cutout)
+
+ if self.InnerCrop > 0:
+ for i in range(self.InnerCrop):
+ size = int(
+ torch.rand([])**self.IC_Size_Pow * (max_size - min_size) +
+ min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[:, :, offsety:offsety + size,
+ offsetx:offsetx + size]
+ if i <= int(self.IC_Grey_P * self.InnerCrop):
+ cutout = gray(cutout)
+ cutout = resize(cutout, out_shape=output_shape)
+ cutouts.append(cutout)
+ cutouts = torch.cat(cutouts)
+ if not skip_augs:
+ cutouts = self.augs(cutouts)
+ return cutouts
+
+
+def parse_prompt(prompt):
+ """Parse prompt, return text and text weight."""
+ if prompt.startswith('http://') or prompt.startswith('https://'):
+ vals = prompt.rsplit(':', 2)
+ vals = [vals[0] + ':' + vals[1], *vals[2:]]
+ else:
+ vals = prompt.rsplit(':', 1)
+ vals = vals + ['', '1'][len(vals):]
+ return vals[0], float(vals[1])
+
+
+def split_prompts(prompts, max_frames=1):
+ """Split prompts to a list of prompts."""
+ prompt_series = pd.Series([np.nan for a in range(max_frames)])
+ for i, prompt in prompts.items():
+ prompt_series[i] = prompt
+ # prompt_series = prompt_series.astype(str)
+ prompt_series = prompt_series.ffill().bfill()
+ return prompt_series
+
+
+class ImageTextGuider(nn.Module):
+ """Disco-Diffusion uses text and images to guide image generation. We will
+ use the clip models to extract text and image features as prompts, and then
+ during the iteration, the features of the image patches are computed, and
+ the similarity loss between the prompts features and the generated features
+ is computed. Other losses also include RGB Range loss, total variation
+ loss. Using these losses we can guide the image generation towards the
+ desired target.
+
+ Args:
+ clip_models (List[Dict]): List of clip model settings.
+ """
+
+ def __init__(self, clip_models):
+ super().__init__()
+ self.clip_models = clip_models
+ self.lpips_model = lpips.LPIPS(net='vgg')
+
+ def frame_prompt_from_text(self, text_prompts, frame_num=0):
+ """Get current frame prompt."""
+ prompts_series = split_prompts(text_prompts)
+ if prompts_series is not None and frame_num >= len(prompts_series):
+ frame_prompt = prompts_series[-1]
+ elif prompts_series is not None:
+ frame_prompt = prompts_series[frame_num]
+ else:
+ frame_prompt = []
+ return frame_prompt
+
+ def compute_prompt_stats(self,
+ text_prompts=[],
+ image_prompt=None,
+ fuzzy_prompt=False,
+ rand_mag=0.05):
+ """Compute prompts statistics.
+
+ Args:
+ text_prompts (list): Text prompts. Defaults to [].
+ image_prompt (list): Image prompts. Defaults to None.
+ fuzzy_prompt (bool, optional): Controls whether to add multiple
+ noisy prompts to the prompt losses. If True, can increase
+ variability of image output. Defaults to False.
+ rand_mag (float, optional): Controls the magnitude of the
+ random noise added by fuzzy_prompt. Defaults to 0.05.
+ """
+ model_stats = []
+ frame_prompt = self.frame_prompt_from_text(text_prompts)
+ for clip_model in self.clip_models:
+ model_stat = {
+ 'clip_model': None,
+ 'target_embeds': [],
+ 'make_cutouts': None,
+ 'weights': []
+ }
+ model_stat['clip_model'] = clip_model
+
+ for prompt in frame_prompt:
+ txt, weight = parse_prompt(prompt)
+ txt = clip_model.model.encode_text(
+ clip.tokenize(prompt).to(self.device)).float()
+
+ if fuzzy_prompt:
+ for i in range(25):
+ model_stat['target_embeds'].append(
+ (txt +
+ torch.randn(txt.shape).cuda() * rand_mag).clamp(
+ 0, 1))
+ model_stat['weights'].append(weight)
+ else:
+ model_stat['target_embeds'].append(txt)
+ model_stat['weights'].append(weight)
+ model_stat['target_embeds'] = torch.cat(
+ model_stat['target_embeds'])
+ model_stat['weights'] = torch.tensor(
+ model_stat['weights'], device=self.device)
+ if model_stat['weights'].sum().abs() < 1e-3:
+ raise RuntimeError('The weights must not sum to 0.')
+ model_stat['weights'] /= model_stat['weights'].sum().abs()
+ model_stats.append(model_stat)
+ return model_stats
+
+ def cond_fn(self,
+ model,
+ diffusion_scheduler,
+ x,
+ t,
+ beta_prod_t,
+ model_stats,
+ secondary_model=None,
+ init_image=None,
+ clamp_grad=True,
+ clamp_max=0.05,
+ clip_guidance_scale=5000,
+ init_scale=1000,
+ tv_scale=0.,
+ sat_scale=0.,
+ range_scale=150,
+ cut_overview=[12] * 400 + [4] * 600,
+ cut_innercut=[4] * 400 + [12] * 600,
+ cut_ic_pow=[1] * 1000,
+ cut_icgray_p=[0.2] * 400 + [0] * 600,
+ cutn_batches=4):
+ """Clip guidance function.
+
+ Args:
+ model (nn.Module): _description_
+ diffusion_scheduler (object): _description_
+ x (torch.Tensor): _description_
+ t (int): _description_
+ beta_prod_t (torch.Tensor): _description_
+ model_stats (List[torch.Tensor]): _description_
+ secondary_model (nn.Module): A smaller secondary diffusion model
+ trained by Katherine Crowson to remove noise from intermediate
+ timesteps to prepare them for CLIP.
+ Ref: https://twitter.com/rivershavewings/status/1462859669454536711 # noqa
+ Defaults to None.
+ init_image (torch.Tensor): Initial image for denoising.
+ Defaults to None.
+ clamp_grad (bool, optional): Whether clamp gradient. Defaults to True.
+ clamp_max (float, optional): Clamp max values. Defaults to 0.05.
+ clip_guidance_scale (int, optional): The scale of influence of
+ clip guidance on image generation. Defaults to 5000.
+ """
+ with torch.enable_grad():
+ x_is_NaN = False
+ x = x.detach().requires_grad_()
+ n = x.shape[0]
+ if secondary_model is not None:
+ alpha = torch.tensor(
+ diffusion_scheduler.alphas_cumprod[t]**0.5,
+ dtype=torch.float32)
+ sigma = torch.tensor(
+ (1 - diffusion_scheduler.alphas_cumprod[t])**0.5,
+ dtype=torch.float32)
+ cosine_t = alpha_sigma_to_t(alpha, sigma).to(x.device)
+ model_output = secondary_model(
+ x, cosine_t[None].repeat([x.shape[0]]))
+ pred_original_sample = model_output['pred']
+ else:
+ model_output = model(x, t)['outputs']
+ model_output, predicted_variance = torch.split(
+ model_output, x.shape[1], dim=1)
+ alpha_prod_t = 1 - beta_prod_t
+ pred_original_sample = (x - beta_prod_t**(0.5) *
+ model_output) / alpha_prod_t**(0.5)
+ # fac = diffusion_scheduler_output['beta_prod_t']** (0.5)
+ # x_in = diffusion_scheduler_output['original_sample'] * fac + x * (1 - fac) # noqa
+ fac = beta_prod_t**(0.5)
+ x_in = pred_original_sample * fac + x * (1 - fac)
+ x_in_grad = torch.zeros_like(x_in)
+ for model_stat in model_stats:
+ for i in range(cutn_batches):
+ t_int = int(t.item()) + 1
+ try:
+ input_resolution = model_stat[
+ 'clip_model'].model.visual.input_resolution
+ except AttributeError:
+ input_resolution = 224
+
+ cuts = MakeCutoutsDango(
+ input_resolution,
+ Overview=cut_overview[1000 - t_int],
+ InnerCrop=cut_innercut[1000 - t_int],
+ IC_Size_Pow=cut_ic_pow[1000 - t_int],
+ IC_Grey_P=cut_icgray_p[1000 - t_int])
+ clip_in = normalize(cuts(x_in.add(1).div(2)))
+ image_embeds = model_stat['clip_model'].model.encode_image(
+ clip_in).float()
+ dists = spherical_dist_loss(
+ image_embeds.unsqueeze(1),
+ model_stat['target_embeds'].unsqueeze(0))
+ dists = dists.view([
+ cut_overview[1000 - t_int] +
+ cut_innercut[1000 - t_int], n, -1
+ ])
+ losses = dists.mul(model_stat['weights']).sum(2).mean(0)
+ x_in_grad += torch.autograd.grad(
+ losses.sum() * clip_guidance_scale,
+ x_in)[0] / cutn_batches
+ tv_losses = tv_loss(x_in)
+ range_losses = range_loss(pred_original_sample)
+ sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
+ loss = tv_losses.sum() * tv_scale + range_losses.sum(
+ ) * range_scale + sat_losses.sum() * sat_scale
+ if init_image is not None and init_scale:
+ init_losses = self.lpips_model(x_in, init_image)
+ loss = loss + init_losses.sum() * init_scale
+ x_in_grad += torch.autograd.grad(loss, x_in)[0]
+ if not torch.isnan(x_in_grad).any():
+ grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
+ else:
+ x_is_NaN = True
+ grad = torch.zeros_like(x)
+ if clamp_grad and not x_is_NaN:
+ magnitude = grad.square().mean().sqrt()
+ return grad * magnitude.clamp(max=clamp_max) / magnitude
+ return grad
+
+ @property
+ def device(self):
+ """Get current device of the model.
+
+ Returns:
+ torch.device: The current device of the model.
+ """
+ return next(self.parameters()).device
+
+ def forward(self, x):
+ """forward function."""
+ raise NotImplementedError('No forward function for disco guider')
diff --git a/mmedit/models/editors/disco_diffusion/secondary_model.py b/mmedit/models/editors/disco_diffusion/secondary_model.py
new file mode 100644
index 0000000000..0fd0646c3d
--- /dev/null
+++ b/mmedit/models/editors/disco_diffusion/secondary_model.py
@@ -0,0 +1,165 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from mmedit.registry import MODELS
+
+# Note: This model is copied from Disco-Diffusion colab.
+# SourceCode: https://colab.research.google.com/drive/1uGKaBOEACeinAA7jX1_zSFtj_ZW-huHS#scrollTo=XIqUfrmvLIhg # noqa
+
+
+def append_dims(x, n):
+ """Append dims."""
+ return x[(Ellipsis, *(None, ) * (n - x.ndim))]
+
+
+def expand_to_planes(x, shape):
+ """Expand tensor to planes."""
+ return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
+
+
+def alpha_sigma_to_t(alpha, sigma):
+ """convert alpha&sigma to timestep."""
+ return torch.atan2(sigma, alpha) * 2 / math.pi
+
+
+def t_to_alpha_sigma(t):
+ """convert timestep to alpha and sigma."""
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
+
+
+class ConvBlock(nn.Sequential):
+ """Convolution Block.
+
+ Args:
+ c_in (int): Input channels.
+ c_out (int): Output channels.
+ """
+
+ def __init__(self, c_in, c_out):
+ super().__init__(
+ nn.Conv2d(c_in, c_out, 3, padding=1),
+ nn.ReLU(inplace=True),
+ )
+
+
+class SkipBlock(nn.Module):
+ """Skip block wrapper. Wrapping main block and skip block and concat their
+ outputs together.
+
+ Args:
+ main (list): A list of main modules.
+ skip (nn.Module): Skip Module. If not given,
+ set to ``nn.Identity()``. Defaults to None.
+ """
+
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ """Forward function."""
+ return torch.cat([self.main(input), self.skip(input)], dim=1)
+
+
+class FourierFeatures(nn.Module):
+ """Fourier features mapping MLP.
+
+ Args:
+ in_features (int): Input channels.
+ out_features (int): Output channels.
+ std (float): Standard deviation. Defaults to 1..
+ """
+
+ def __init__(self, in_features, out_features, std=1.):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = nn.Parameter(
+ torch.randn([out_features // 2, in_features]) * std)
+
+ def forward(self, input):
+ """Forward function."""
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
+
+
+@MODELS.register_module()
+class SecondaryDiffusionImageNet2(nn.Module):
+ """A smaller secondary diffusion model trained by Katherine Crowson to
+ remove noise from intermediate timesteps to prepare them for CLIP.
+
+ Ref: https://twitter.com/rivershavewings/status/1462859669454536711 # noqa
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.in_channels = 3
+ c = 64 # The base channel count
+ cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
+
+ self.timestep_embed = FourierFeatures(1, 16)
+ self.down = nn.AvgPool2d(2)
+ self.up = nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=False)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, cs[0]),
+ ConvBlock(cs[0], cs[0]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[0], cs[1]),
+ ConvBlock(cs[1], cs[1]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[1], cs[2]),
+ ConvBlock(cs[2], cs[2]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[2], cs[3]),
+ ConvBlock(cs[3], cs[3]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[3], cs[4]),
+ ConvBlock(cs[4], cs[4]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[4], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[4]),
+ self.up,
+ ]),
+ ConvBlock(cs[4] * 2, cs[4]),
+ ConvBlock(cs[4], cs[3]),
+ self.up,
+ ]),
+ ConvBlock(cs[3] * 2, cs[3]),
+ ConvBlock(cs[3], cs[2]),
+ self.up,
+ ]),
+ ConvBlock(cs[2] * 2, cs[2]),
+ ConvBlock(cs[2], cs[1]),
+ self.up,
+ ]),
+ ConvBlock(cs[1] * 2, cs[1]),
+ ConvBlock(cs[1], cs[0]),
+ self.up,
+ ]),
+ ConvBlock(cs[0] * 2, cs[0]),
+ nn.Conv2d(cs[0], 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+ """Forward function."""
+ timestep_embed = expand_to_planes(
+ self.timestep_embed(t[:, None]), input.shape)
+ v = self.net(torch.cat([input, timestep_embed], dim=1))
+ alphas, sigmas = map(
+ partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return dict(v=v, pred=pred, eps=eps)
diff --git a/mmedit/models/editors/guided_diffusion/__init__.py b/mmedit/models/editors/guided_diffusion/__init__.py
new file mode 100644
index 0000000000..e66b3cdec5
--- /dev/null
+++ b/mmedit/models/editors/guided_diffusion/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adm import AblatedDiffusionModel
+
+__all__ = ['AblatedDiffusionModel']
diff --git a/mmedit/models/editors/guided_diffusion/adm.py b/mmedit/models/editors/guided_diffusion/adm.py
new file mode 100644
index 0000000000..9ae9de0e86
--- /dev/null
+++ b/mmedit/models/editors/guided_diffusion/adm.py
@@ -0,0 +1,304 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from typing import List, Optional
+
+import mmengine
+import torch
+import torch.nn.functional as F
+from mmengine import MessageHub
+from mmengine.model import BaseModel, is_model_wrapper
+from mmengine.optim import OptimWrapperDict
+from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
+from tqdm import tqdm
+
+from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
+from mmedit.structures import EditDataSample, PixelData
+from mmedit.utils.typing import ForwardInputs, SampleList
+
+
+def classifier_grad(classifier, x, t, y=None, classifier_scale=1.0):
+ """compute classification gradient to x."""
+ assert y is not None
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ logits = classifier(x_in, t)
+ log_probs = F.log_softmax(logits, dim=-1)
+ selected = log_probs[range(len(logits)), y.view(-1)]
+ return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
+
+
+@MODELS.register_module('ADM')
+@MODELS.register_module('GuidedDiffusion')
+@MODELS.register_module()
+class AblatedDiffusionModel(BaseModel):
+ """Guided diffusion Model.
+
+ Args:
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`BaseDataPreprocessor`.
+ unet (ModelType): Config of denoising Unet.
+ diffusion_scheduler (ModelType): Config of diffusion_scheduler
+ scheduler.
+ use_fp16 (bool): Whether to use fp16 for unet model. Defaults to False.
+ classifier (ModelType): Config of classifier. Defaults to None.
+ pretrained_cfgs (dict): Path Config for pretrained weights. Usually
+ this is a dict contains module name and the corresponding ckpt
+ path.Defaults to None.
+ """
+
+ def __init__(self,
+ data_preprocessor,
+ unet,
+ diffusion_scheduler,
+ use_fp16=False,
+ classifier=None,
+ classifier_scale=1.0,
+ pretrained_cfgs=None):
+
+ super().__init__(data_preprocessor=data_preprocessor)
+ self.unet = MODULES.build(unet)
+ self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build(
+ diffusion_scheduler)
+ if classifier:
+ self.classifier = MODULES.build(classifier)
+ else:
+ self.classifier = None
+ self.classifier_scale = classifier_scale
+
+ if pretrained_cfgs:
+ self.load_pretrained_models(pretrained_cfgs)
+ if use_fp16:
+ mmengine.print_log('Convert unet modules to floatpoint16')
+ self.unet.convert_to_fp16()
+
+ def load_pretrained_models(self, pretrained_cfgs):
+ """_summary_
+
+ Args:
+ pretrained_cfgs (_type_): _description_
+ """
+ for key, ckpt_cfg in pretrained_cfgs.items():
+ prefix = ckpt_cfg.get('prefix', '')
+ map_location = ckpt_cfg.get('map_location', 'cpu')
+ strict = ckpt_cfg.get('strict', True)
+ ckpt_path = ckpt_cfg.get('ckpt_path')
+ state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
+ map_location)
+ getattr(self, key).load_state_dict(state_dict, strict=strict)
+ mmengine.print_log(f'Load pretrained {key} from {ckpt_path}')
+
+ @property
+ def device(self):
+ """Get current device of the model.
+
+ Returns:
+ torch.device: The current device of the model.
+ """
+ return next(self.parameters()).device
+
+ def infer(self,
+ init_image=None,
+ batch_size=1,
+ num_inference_steps=1000,
+ labels=None,
+ classifier_scale=0.0,
+ show_progress=False):
+ """_summary_
+
+ Args:
+ init_image (_type_, optional): _description_. Defaults to None.
+ batch_size (int, optional): _description_. Defaults to 1.
+ num_inference_steps (int, optional): _description_.
+ Defaults to 1000.
+ labels (_type_, optional): _description_. Defaults to None.
+ show_progress (bool, optional): _description_. Defaults to False.
+
+ Returns:
+ _type_: _description_
+ """
+ # Sample gaussian noise to begin loop
+ if init_image is None:
+ image = torch.randn((batch_size, self.unet.in_channels,
+ self.unet.image_size, self.unet.image_size))
+ image = image.to(self.device)
+ else:
+ image = init_image
+
+ if isinstance(labels, int):
+ labels = torch.tensor(labels).repeat(batch_size, 1)
+ elif labels is None:
+ labels = torch.randint(
+ low=0,
+ high=self.unet.num_classes,
+ size=(batch_size, ),
+ device=self.device)
+
+ # set step values
+ if num_inference_steps > 0:
+ self.diffusion_scheduler.set_timesteps(num_inference_steps)
+
+ timesteps = self.diffusion_scheduler.timesteps
+
+ if show_progress and mmengine.dist.is_main_process():
+ timesteps = tqdm(timesteps)
+ for t in timesteps:
+ # 1. predicted model_output
+ model_output = self.unet(image, t, label=labels)['outputs']
+
+ # 2. compute previous image: x_t -> x_t-1
+ diffusion_scheduler_output = self.diffusion_scheduler.step(
+ model_output, t, image)
+
+ # 3. applying classifier guide
+ if self.classifier and classifier_scale != 0.0:
+ gradient = classifier_grad(
+ self.classifier,
+ image,
+ t,
+ labels,
+ classifier_scale=classifier_scale)
+ guided_mean = (
+ diffusion_scheduler_output['mean'].float() +
+ diffusion_scheduler_output['sigma'] * gradient.float())
+ image = guided_mean + diffusion_scheduler_output[
+ 'sigma'] * diffusion_scheduler_output['noise']
+ else:
+ image = diffusion_scheduler_output['prev_sample']
+
+ return {'samples': image}
+
+ def forward(self,
+ inputs: ForwardInputs,
+ data_samples: Optional[list] = None,
+ mode: Optional[str] = None) -> List[EditDataSample]:
+ """_summary_
+
+ Args:
+ inputs (ForwardInputs): _description_
+ data_samples (Optional[list], optional): _description_.
+ Defaults to None.
+ mode (Optional[str], optional): _description_. Defaults to None.
+
+ Returns:
+ List[EditDataSample]: _description_
+ """
+ init_image = inputs.get('init_image', None)
+ batch_size = inputs.get('batch_size', 1)
+ labels = data_samples.get('labels', None)
+ sample_kwargs = inputs.get('sample_kwargs', dict())
+
+ num_inference_steps = sample_kwargs.get(
+ 'num_inference_steps',
+ self.diffusion_scheduler.num_train_timesteps)
+ show_progress = sample_kwargs.get('show_progress', False)
+ classifier_scale = sample_kwargs.get('classifier_scale',
+ self.classifier_scale)
+
+ outputs = self.infer(
+ init_image=init_image,
+ batch_size=batch_size,
+ num_inference_steps=num_inference_steps,
+ show_progress=show_progress,
+ classifier_scale=classifier_scale)
+
+ batch_sample_list = []
+ for idx in range(batch_size):
+ gen_sample = EditDataSample()
+ if data_samples:
+ gen_sample.update(data_samples[idx])
+ if isinstance(outputs, dict):
+ gen_sample.ema = EditDataSample(
+ fake_img=PixelData(data=outputs['ema'][idx]),
+ sample_model='ema')
+ gen_sample.orig = EditDataSample(
+ fake_img=PixelData(data=outputs['orig'][idx]),
+ sample_model='orig')
+ gen_sample.sample_model = 'ema/orig'
+ gen_sample.set_gt_label(labels[idx])
+ gen_sample.ema.set_gt_label(labels[idx])
+ gen_sample.orig.set_gt_label(labels[idx])
+ else:
+ gen_sample.fake_img = PixelData(data=outputs[idx])
+ gen_sample.set_gt_label(labels[idx])
+
+ # Append input condition (noise and sample_kwargs) to
+ # batch_sample_list
+ if init_image is not None:
+ gen_sample.noise = init_image[idx]
+ gen_sample.sample_kwargs = deepcopy(sample_kwargs)
+ batch_sample_list.append(gen_sample)
+ return batch_sample_list
+
+ def val_step(self, data: dict) -> SampleList:
+ """Gets the generated image of given data.
+
+ Calls ``self.data_preprocessor(data)`` and
+ ``self(inputs, data_sample, mode=None)`` in order. Return the
+ generated results which will be passed to evaluator.
+
+ Args:
+ data (dict): Data sampled from metric specific
+ sampler. More detials in `Metrics` and `Evaluator`.
+
+ Returns:
+ SampleList: Generated image or image dict.
+ """
+ data = self.data_preprocessor(data)
+ outputs = self(**data)
+ return outputs
+
+ def test_step(self, data: dict) -> SampleList:
+ """Gets the generated image of given data. Same as :meth:`val_step`.
+
+ Args:
+ data (dict): Data sampled from metric specific
+ sampler. More detials in `Metrics` and `Evaluator`.
+
+ Returns:
+ List[EditDataSample]: Generated image or image dict.
+ """
+ data = self.data_preprocessor(data)
+ outputs = self(**data)
+ return outputs
+
+ def train_step(self, data: dict, optim_wrapper: OptimWrapperDict):
+ """_summary_
+
+ Args:
+ data (dict): _description_
+ optim_wrapper (OptimWrapperDict): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ message_hub = MessageHub.get_current_instance()
+ curr_iter = message_hub.get_info('iter')
+
+ # sampling x0 and timestep
+ data = self.data_preprocessor(data)
+ real_imgs = data['inputs']
+ timestep = self.diffusion_scheduler.sample_timestep()
+
+ # calculating loss
+ loss_dict = self.diffusion_scheduler.training_loss(
+ self.unet, real_imgs, timestep)
+ loss, log_vars = self._parse_losses(loss_dict)
+ optim_wrapper['denoising'].update_params(loss)
+
+ # update EMA
+ if self.with_ema_denoising and (curr_iter + 1) >= self.ema_start:
+ self.denoising_ema.update_parameters(
+ self.denoising_ema.
+ module if is_model_wrapper(self.denoising) else self.denoising)
+ # if not update buffer, copy buffer from orig model
+ if not self.denoising_ema.update_buffers:
+ self.denoising_ema.sync_buffers(
+ self.denoising.module
+ if is_model_wrapper(self.denoising) else self.denoising)
+ elif self.with_ema_denoising:
+ # before ema, copy weights from orig
+ self.denoising_ema.sync_parameters(
+ self.denoising.
+ module if is_model_wrapper(self.denoising) else self.denoising)
+
+ return log_vars
diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py
index cf4cfb9ccb..9796b651ef 100644
--- a/mmedit/models/losses/__init__.py
+++ b/mmedit/models/losses/__init__.py
@@ -16,7 +16,7 @@
from .perceptual_loss import (PerceptualLoss, PerceptualVGG,
TransferalPerceptualLoss)
from .pixelwise_loss import (CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss,
- PSNRLoss)
+ PSNRLoss, tv_loss)
__all__ = [
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss',
@@ -28,5 +28,5 @@
'CLIPLoss', 'CLIPLossComps', 'DiscShiftLossComps', 'FaceIdLossComps',
'GANLossComps', 'GeneratorPathRegularizerComps',
'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss',
- 'PSNRLoss'
+ 'tv_loss', 'PSNRLoss'
]
diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py
index e2faf0b5ad..3ba36d5fba 100644
--- a/mmedit/models/losses/pixelwise_loss.py
+++ b/mmedit/models/losses/pixelwise_loss.py
@@ -53,6 +53,14 @@ def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target)**2 + eps)
+def tv_loss(input):
+ """L2 total variation loss, as in Mahendran et al."""
+ input = F.pad(input, (0, 1, 0, 1), 'replicate')
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+ return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+
@LOSSES.register_module()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
diff --git a/mmedit/models/utils/diffusion_utils.py b/mmedit/models/utils/diffusion_utils.py
new file mode 100644
index 0000000000..b186b29e94
--- /dev/null
+++ b/mmedit/models/utils/diffusion_utils.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """Create a beta schedule that discretizes the given alpha_t_bar
+ function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Source: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py#L49 # noqa
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2)**2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float64)
diff --git a/mmedit/registry.py b/mmedit/registry.py
index ff334935d8..60564ca5da 100644
--- a/mmedit/registry.py
+++ b/mmedit/registry.py
@@ -62,3 +62,6 @@
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=registry.OPTIM_WRAPPERS)
+
+# manage diffusion_schedulers
+DIFFUSION_SCHEDULERS = Registry('diffusion scheduler')
diff --git a/model-index.yml b/model-index.yml
index e282efcfb7..f45127255f 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -10,6 +10,7 @@ Import:
- configs/deepfillv2/metafile.yml
- configs/dic/metafile.yml
- configs/dim/metafile.yml
+- configs/disco_diffusion/metafile.yml
- configs/edsr/metafile.yml
- configs/edvr/metafile.yml
- configs/esrgan/metafile.yml
@@ -18,6 +19,7 @@ Import:
- configs/ggan/metafile.yml
- configs/glean/metafile.yml
- configs/global_local/metafile.yml
+- configs/guided_diffusion/metafile.yml
- configs/iconvsr/metafile.yml
- configs/indexnet/metafile.yml
- configs/inst_colorization/metafile.yml
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 80265b1165..7d9477240b 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,3 +1,4 @@
clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
mmdet >= 3.0.0rc2
+open_clip_torch
PyQt5
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 02198bf6ec..1035119e93 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -14,6 +14,7 @@ opencv-python!=4.5.5.62,!=4.5.5.64
# https://github.com/opencv/opencv/issues/21366
# It seems to be fixed in https://github.com/opencv/opencv/pull/21382
Pillow
+resize_right
tensorboard
torch
torchvision
diff --git a/tests/test_models/test_editors/test_disco_diffusion/test_disco_diffusion.py b/tests/test_models/test_editors/test_disco_diffusion/test_disco_diffusion.py
new file mode 100644
index 0000000000..ffd350ecf9
--- /dev/null
+++ b/tests/test_models/test_editors/test_disco_diffusion/test_disco_diffusion.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+import torch
+import torch.nn as nn
+
+from mmedit.models import DDIMScheduler, DenoisingUnet, DiscoDiffusion
+from mmedit.utils import register_all_modules
+
+register_all_modules()
+
+
+class clip_mock(nn.Module):
+
+ def __init__(self, device='cuda'):
+ super().__init__()
+ self.register_buffer('tensor', torch.randn([1, 512]))
+
+ def encode_image(self, inputs):
+ return inputs.mean() * self.tensor.repeat(inputs.shape[0], 1).to(
+ inputs.device)
+
+ def encode_text(self, inputs):
+ return self.tensor.repeat(inputs.shape[0], 1).to(inputs.device)
+
+ def forward(self, x):
+ return x
+
+
+class clip_mock_wrapper(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.model = clip_mock()
+
+ def forward(self, x):
+ return x
+
+
+class TestDiscoDiffusion(TestCase):
+
+ def test_init(self):
+ # unet
+ unet32 = DenoisingUnet(
+ image_size=32,
+ in_channels=3,
+ base_channels=8,
+ resblocks_per_downsample=2,
+ attention_res=(8, ),
+ norm_cfg=dict(type='GN32', num_groups=8),
+ dropout=0.0,
+ num_classes=0,
+ use_fp16=True,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=2,
+ num_head_channels=8,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True)
+ # mock clip
+ clip_models = [clip_mock_wrapper(), clip_mock_wrapper()]
+ # diffusion_scheduler
+ diffusion_scheduler = DDIMScheduler(
+ variance_type='learned_range',
+ beta_schedule='linear',
+ clip_sample=False)
+
+ self.disco_diffusion = DiscoDiffusion(
+ unet=unet32,
+ diffusion_scheduler=diffusion_scheduler,
+ secondary_model=None,
+ clip_models=clip_models,
+ use_fp16=True)
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+ def test_infer(self):
+ self.disco_diffusion.cuda().eval()
+ # test model structure
+ text_prompts = {
+ 0: ['clouds surround the mountains and palaces,sunshine,lake']
+ }
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 32, 32)
+ # test with different text prompts
+ text_prompts = {
+ 0: [
+ 'a portrait of supergirl, by artgerm, rosstran, trending on artstation.' # noqa
+ ]
+ }
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 32, 32)
+
+ # test with init_image
+ init_image = 'tests/data/image/face/000001.png'
+ text_prompts = {
+ 0: [
+ 'a portrait of supergirl, by artgerm, rosstran, trending on artstation.' # noqa
+ ]
+ }
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ init_image=init_image,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 32, 32)
+
+ # test with different image resolution
+ text_prompts = {
+ 0: ['clouds surround the mountains and palaces,sunshine,lake']
+ }
+ image = self.disco_diffusion.infer(
+ height=64,
+ width=128,
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 64, 128)
+
+ # clip guidance scale
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ clip_guidance_scale=8000,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 32, 32)
+
+ # test with different loss settings
+ tv_scale = 0.5
+ sat_scale = 0.5
+ range_scale = 100
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8,
+ tv_scale=tv_scale,
+ sat_scale=sat_scale,
+ range_scale=range_scale)['samples']
+ assert image.shape == (1, 3, 32, 32)
+
+ # test with different cutter settings
+ cut_overview = [12] * 100 + [4] * 900
+ cut_innercut = [4] * 100 + [12] * 900
+ cut_ic_pow = [1] * 200 + [0] * 800
+ cut_icgray_p = [0.2] * 200 + [0] * 800
+ cutn_batches = 2
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8,
+ cut_overview=cut_overview,
+ cut_innercut=cut_innercut,
+ cut_ic_pow=cut_ic_pow,
+ cut_icgray_p=cut_icgray_p,
+ cutn_batches=cutn_batches)['samples']
+ assert image.shape == (1, 3, 32, 32)
+
+ # test with different unet
+ unet64 = DenoisingUnet(
+ image_size=64,
+ in_channels=3,
+ base_channels=8,
+ resblocks_per_downsample=2,
+ attention_res=(8, ),
+ norm_cfg=dict(type='GN32', num_groups=8),
+ dropout=0.0,
+ num_classes=0,
+ use_fp16=True,
+ resblock_updown=True,
+ attention_cfg=dict(
+ type='MultiHeadAttentionBlock',
+ num_heads=2,
+ num_head_channels=8,
+ use_new_attention_order=False),
+ use_scale_shift_norm=True).cuda()
+ unet64.convert_to_fp16()
+ self.disco_diffusion.unet = unet64
+ image = self.disco_diffusion.infer(
+ text_prompts=text_prompts,
+ show_progress=True,
+ num_inference_steps=5,
+ eta=0.8)['samples']
+ assert image.shape == (1, 3, 64, 64)