diff --git a/projects/glide/configs/README.md b/projects/glide/configs/README.md
index 91ba07ac40..0d7e4e3ca6 100644
--- a/projects/glide/configs/README.md
+++ b/projects/glide/configs/README.md
@@ -34,9 +34,10 @@ Diffusion models have recently been shown to generate high-quality synthetic ima
 
 **Laion**
 
-| Method | Resolution | Config                                                                      | Weights                                                                              |
-| ------ | ---------- | -------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
-| Glide  | 64x64      | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmagic/glide/glide_laion-64x64-02afff47.pth) |
+| Method | Resolution       | Config                                                                       | Weights                                                                                  |
+| ------ | ---------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
+| Glide  | 64x64            | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py)  | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth)  |
+| Glide  | 64x64 -> 256x256 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth) |
 
 ## Quick Start
 
@@ -66,6 +67,34 @@ with torch.no_grad():
                           show_progress=True)['samples']
 ```
 
+You can also synthesize images at 256x256 resolution:
+
+```python
+import torch
+from torchvision.utils import save_image
+from mmedit.apis import init_model
+from mmengine.registry import init_default_scope
+from projects.glide.models import *
+
+init_default_scope('mmedit')
+
+config = 'projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py'
+ckpt = 'https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth'
+model = init_model(config, ckpt).cuda().eval()
+prompt = "an oil painting of a corgi"
+
+with torch.no_grad():
+    samples = model.infer(init_image=None,
+                          prompt=prompt,
+                          batch_size=16,
+                          guidance_scale=3.,
+                          num_inference_steps=100,
+                          labels=None,
+                          classifier_scale=0.0,
+                          show_progress=True)['samples']
+save_image(samples, "corgi.png", nrow=4, normalize=True, value_range=(-1, 1))
+```
+
 ## Citation
 
 ```bibtex
diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py
new file mode 100644
index 0000000000..23774d0741
--- /dev/null
+++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py
@@ -0,0 +1,68 @@
+unet_cfg = dict(
+    type='Text2ImUNet',
+    image_size=64,
+    base_channels=192,
+    in_channels=3,
+    resblocks_per_downsample=3,
+    attention_res=(32, 16, 8),
+    norm_cfg=dict(type='GN32', num_groups=32),
+    dropout=0.1,
+    num_classes=0,
+    use_fp16=False,
+    resblock_updown=True,
+    attention_cfg=dict(
+        type='MultiHeadAttentionBlock',
+        num_heads=1,
+        num_head_channels=64,
+        use_new_attention_order=False,
+        encoder_channels=512),
+    use_scale_shift_norm=True,
+    text_ctx=128,
+    xf_width=512,
+    xf_layers=16,
+    xf_heads=8,
+    xf_final_ln=True,
+    xf_padding=True,
+)
+unet_up_cfg = dict(
+    type='SuperResText2ImUNet',
+    image_size=256,
+    base_channels=192,
+    in_channels=3,
+    output_cfg=dict(var='FIXED'),
+    resblocks_per_downsample=2,
+    attention_res=(32, 16, 8),
+    norm_cfg=dict(type='GN32', num_groups=32),
+    dropout=0.1,
+    num_classes=0,
+    use_fp16=False,
+    resblock_updown=True,
+    attention_cfg=dict(
+        type='MultiHeadAttentionBlock',
+        num_heads=1,
+        num_head_channels=64,
+        use_new_attention_order=False,
+        encoder_channels=512),
+    use_scale_shift_norm=True,
+    text_ctx=128,
+    xf_width=512,
+    xf_layers=16,
+    xf_heads=8,
+    xf_final_ln=True,
+    xf_padding=True,
+)
+
+model = dict(
+    type='Glide',
+    data_preprocessor=dict(type='DataPreprocessor', mean=[127.5], std=[127.5]),
+    unet=unet_cfg,
+    diffusion_scheduler=dict(
+        type='EditDDIMScheduler',
+        variance_type='learned_range',
+        beta_schedule='squaredcos_cap_v2'),
+    unet_up=unet_up_cfg,
+    diffusion_scheduler_up=dict(
+        type='EditDDIMScheduler',
+        variance_type='learned_range',
+        beta_schedule='linear'),
+    use_fp16=False)
diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py
index f848349a25..fd2472b2a7 100644
--- a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py
+++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py
@@ -28,7 +28,7 @@
         xf_padding=True,
     ),
     diffusion_scheduler=dict(
-        type='DDIMScheduler',
+        type='EditDDIMScheduler',
         variance_type='learned_range',
         beta_schedule='squaredcos_cap_v2'),
     use_fp16=False)
diff --git a/projects/glide/models/__init__.py b/projects/glide/models/__init__.py
index 6b40ebf542..59947b2cb8 100644
--- a/projects/glide/models/__init__.py
+++ b/projects/glide/models/__init__.py
@@ -1,4 +1,4 @@
 from .glide import Glide
-from .text2im_unet import Text2ImUNet
+from .text2im_unet import SuperResText2ImUNet, Text2ImUNet
 
-__all__ = ['Text2ImUNet', 'Glide']
+__all__ = ['Text2ImUNet', 'Glide', 'SuperResText2ImUNet']
diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py
index ad0eb2529e..dda6ea2989 100644
--- a/projects/glide/models/glide.py
+++ b/projects/glide/models/glide.py
@@ -12,12 +12,10 @@
 from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
 from tqdm import tqdm
 
-from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
+from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS
 from mmagic.structures import DataSample
 from mmagic.utils.typing import ForwardInputs, SampleList
 
-# from .guider import ImageTextGuider
-
 ModelType = Union[Dict, nn.Module]
 
 
@@ -32,38 +30,70 @@ def classifier_grad(classifier, x, t, y=None, classifier_scale=1.0):
     return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
 
 
+@MODELS.register_module('GLIDE')
 @MODELS.register_module()
 class Glide(BaseModel):
-    """Guided diffusion Model.
+    """GLIDE: Guided language to image diffusion for generation and editing.
+
+    Refer to: https://github.com/openai/glide-text2im.
 
     Args:
-        data_preprocessor (dict, optional): The pre-process config of
-            :class:`BaseDataPreprocessor`.
-        unet (ModelType): Config of denoising Unet.
-        diffusion_scheduler (ModelType): Config of diffusion_scheduler
+        unet (ModelType): Configuration for the denoising Unet.
+        diffusion_scheduler (ModelType): Configuration for the diffusion
             scheduler.
-        use_fp16 (bool): Whether to use fp16 for unet model. Defaults to False.
-        classifier (ModelType): Config of classifier. Defaults to None.
-        pretrained_cfgs (dict): Path Config for pretrained weights. Usually
-            this is a dict contains module name and the corresponding ckpt
-            path.Defaults to None.
+        unet_up (ModelType, optional): Configuration for the upsampling
+            denoising UNet. Defaults to None.
+        diffusion_scheduler_up (ModelType, optional): Configuration for
+            the upsampling diffusion scheduler. Defaults to None.
+        use_fp16 (bool, optional): Whether to use fp16 for the unet model.
+            Defaults to False.
+        classifier (ModelType, optional): Configuration for the classifier.
+            Defaults to None.
+        classifier_scale (float): Classifier scale for classifier guidance.
+            Defaults to 1.0.
+        data_preprocessor (dict, optional): Configuration for the data
+            preprocessor of :class:`BaseDataPreprocessor`. Defaults to
+            ``dict(type='DataPreprocessor')``.
+        pretrained_cfgs (dict, optional): Path configuration for pretrained
+            weights. Usually, this is a dict containing the module name and
+            the corresponding ckpt path. Defaults to None.
     """
 
     def __init__(self,
-                 data_preprocessor,
-                 unet,
-                 diffusion_scheduler,
-                 use_fp16=False,
-                 classifier=None,
-                 classifier_scale=1.0,
-                 pretrained_cfgs=None):
+                 unet: ModelType,
+                 diffusion_scheduler: ModelType,
+                 unet_up: Optional[ModelType] = None,
+                 diffusion_scheduler_up: Optional[ModelType] = None,
+                 use_fp16: bool = False,
+                 classifier: Optional[dict] = None,
+                 classifier_scale: float = 1.0,
+                 data_preprocessor: Optional[ModelType] = dict(
+                     type='DataPreprocessor'),
+                 pretrained_cfgs: Optional[dict] = None):
         super().__init__(data_preprocessor=data_preprocessor)
-        self.unet = MODULES.build(unet)
+        self.unet = unet if isinstance(unet, nn.Module) else MODELS.build(unet)
         self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build(
-            diffusion_scheduler)
+            diffusion_scheduler) if isinstance(diffusion_scheduler,
+                                               dict) else diffusion_scheduler
+
+        self.unet_up = None
+        self.diffusion_scheduler_up = None
+        if unet_up:
+            self.unet_up = unet_up if isinstance(
+                unet_up, nn.Module) else MODELS.build(unet_up)
+            if diffusion_scheduler_up:
+                self.diffusion_scheduler_up = DIFFUSION_SCHEDULERS.build(
+                    diffusion_scheduler_up) if isinstance(
+                        diffusion_scheduler_up,
+                        dict) else diffusion_scheduler_up
+            else:
+                self.diffusion_scheduler_up = deepcopy(
+                    self.diffusion_scheduler)
+
         if classifier:
-            self.classifier = MODULES.build(classifier)
+            self.classifier = MODELS.build(classifier)
         else:
             self.classifier = None
         self.classifier_scale = classifier_scale
@@ -101,26 +131,34 @@ def device(self):
 
     @torch.no_grad()
     def infer(self,
-              init_image=None,
-              prompt=None,
-              batch_size=1,
-              guidance_scale=3.,
-              num_inference_steps=50,
-              labels=None,
-              classifier_scale=0.0,
-              show_progress=False):
-        """_summary_
+              init_image: Optional[torch.Tensor] = None,
+              prompt: Optional[str] = None,
+              batch_size: int = 1,
+              guidance_scale: float = 3.,
+              num_inference_steps: int = 50,
+              num_inference_steps_up: int = 27,
+              labels: Optional[torch.Tensor] = None,
+              classifier_scale: float = 0.0,
+              show_progress: bool = False):
+        """Inference function for guided diffusion.
 
         Args:
-            init_image (_type_, optional): _description_. Defaults to None.
-            batch_size (int, optional): _description_. Defaults to 1.
-            num_inference_steps (int, optional): _description_.
-                Defaults to 1000.
-            labels (_type_, optional): _description_. Defaults to None.
-            show_progress (bool, optional): _description_. Defaults to False.
+            init_image (torch.Tensor, optional): Starting noise for diffusion.
+                Defaults to None.
+            prompt (str): The prompt to guide the image generation.
+            batch_size (int, optional): Batch size for generation.
+                Defaults to 1.
+            guidance_scale (float): Classifier-free guidance scale.
+                Defaults to 3.0.
+            num_inference_steps (int, optional): The number of denoising
+                steps. Defaults to 50.
+            num_inference_steps_up (int, optional): The number of upsampling
+                denoising steps. Defaults to 27.
+            labels (torch.Tensor, optional): Labels for the classifier.
+                Defaults to None.
+            classifier_scale (float): Classifier guidance scale, used only
+                when a classifier is set. Defaults to 0.0.
+            show_progress (bool, optional): Whether to show the progress bar.
+                Defaults to False.
 
         Returns:
-            _type_: _description_
+            dict: A dict whose key ``'samples'`` maps to the generated
+                images.
         """
         # Sample gaussian noise to begin loop
         if init_image is None:
@@ -167,9 +205,6 @@ def infer(self,
                 half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                 eps = torch.cat([half_eps, half_eps], dim=0)
                 noise_pred = torch.cat([eps, rest], dim=1)
-                # noise_pred_text, noise_pred_uncond = model_output.chunk(2)
-                # noise_pred = noise_pred_uncond + guidance_scale *
-                # (noise_pred_text - noise_pred_uncond)
 
             # 2. compute previous image: x_t -> x_t-1
             diffusion_scheduler_output = self.diffusion_scheduler.step(
@@ -191,8 +226,79 @@ def infer(self,
         else:
             image = diffusion_scheduler_output['prev_sample']
 
+        # discard the unconditional half of the doubled batch
+        image = image[:image.shape[0] // 2]
+
+        if self.unet_up:
+            image = self.infer_up(
+                low_res_img=image,
+                batch_size=batch_size,
+                prompt=prompt,
+                num_inference_steps=num_inference_steps_up)
+
         return {'samples': image}
 
+    @torch.no_grad()
+    def infer_up(self,
+                 low_res_img: torch.Tensor,
+                 batch_size: int = 1,
+                 init_image: Optional[torch.Tensor] = None,
+                 prompt: Optional[str] = None,
+                 num_inference_steps: int = 27,
+                 show_progress: bool = False):
+        """Inference function for upsampling guided diffusion.
+
+        Args:
+            low_res_img (torch.Tensor): Low resolution image
+                (shape: [B, C, H, W]) for upsampling.
+            batch_size (int, optional): Batch size for generation.
+                Defaults to 1.
+            init_image (torch.Tensor, optional): Starting noise
+                (shape: [B, C, H, W]) for diffusion. Defaults to None.
+            prompt (str, optional): The text prompt to guide the image
+                generation. Defaults to None.
+            num_inference_steps (int, optional): The number of denoising
+                steps. Defaults to 27.
+            show_progress (bool, optional): Whether to show the progress bar.
+                Defaults to False.
+
+        Returns:
+            torch.Tensor: Generated upsampled images (shape: [B, C, H, W]).
+        """
+        if init_image is None:
+            image = torch.randn(
+                (batch_size, self.unet_up.in_channels // 2,
+                 self.unet_up.image_size, self.unet_up.image_size))
+            image = image.to(self.device)
+        else:
+            image = init_image
+
+        # set step values
+        if num_inference_steps > 0:
+            self.diffusion_scheduler_up.set_timesteps(num_inference_steps)
+        timesteps = self.diffusion_scheduler_up.timesteps
+
+        # text embedding: token ids are integers, the padding mask is boolean
+        tokens = self.unet.tokenizer.encode(prompt)
+        tokens, mask = self.unet.tokenizer.padded_tokens_and_mask(tokens, 128)
+        tokens = torch.tensor(
+            [tokens] * batch_size, dtype=torch.long, device=self.device)
+        mask = torch.tensor(
+            [mask] * batch_size, dtype=torch.bool, device=self.device)
+
+        if show_progress and mmengine.dist.is_main_process():
+            timesteps = tqdm(timesteps)
+
+        for t in timesteps:
+            noise_pred = self.unet_up(
+                image, t, low_res=low_res_img, tokens=tokens, mask=mask)
+            # compute previous image: x_t -> x_t-1
+            diffusion_scheduler_output = self.diffusion_scheduler_up.step(
+                noise_pred, t, image)
+            image = diffusion_scheduler_output['prev_sample']
+
+        return image
+
     def forward(self,
                 inputs: ForwardInputs,
                 data_samples: Optional[list] = None,
@@ -253,6 +359,7 @@ def forward(self,
             batch_sample_list.append(gen_sample)
         return batch_sample_list
 
+    @torch.no_grad()
     def val_step(self, data: dict) -> SampleList:
         """Gets the generated image of given data.
 
@@ -271,6 +378,7 @@ def val_step(self, data: dict) -> SampleList:
         outputs = self(**data)
         return outputs
 
+    @torch.no_grad()
     def test_step(self, data: dict) -> SampleList:
         """Gets the generated image of given data. Same as :meth:`val_step`.
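
To make the classifier-free guidance step in `infer` concrete, the recombination applied to the doubled batch can be reproduced in isolation. This is a toy sketch with random tensors standing in for real UNet outputs; the 6-channel layout (3 eps channels plus 3 learned-variance channels) follows the `learned_range` setting used in the configs above, and all shapes here are illustrative:

```python
import torch

# The conditional and unconditional halves of the doubled batch share one
# forward pass; their eps predictions are then recombined exactly as in the
# denoising loop of `Glide.infer`.
guidance_scale = 3.0
batch_size = 2
# 6 output channels = 3 for eps + 3 for the learned variance ("rest")
model_output = torch.randn(2 * batch_size, 6, 64, 64)

eps, rest = model_output[:, :3], model_output[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
noise_pred = torch.cat([eps, rest], dim=1)

assert noise_pred.shape == model_output.shape
```

After the final step, `infer` keeps only the conditional half of the batch via `image[:image.shape[0] // 2]` before handing it to the upsampler.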
diff --git a/projects/glide/models/text2im_unet.py b/projects/glide/models/text2im_unet.py
index a312299d65..7640a3cac0 100644
--- a/projects/glide/models/text2im_unet.py
+++ b/projects/glide/models/text2im_unet.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from mmagic.models import DenoisingUnet
 from mmagic.registry import MODELS
@@ -9,15 +10,25 @@
 
 @MODELS.register_module()
 class Text2ImUNet(DenoisingUnet):
-    """A UNetModel that conditions on text with an encoding transformer.
-    Expects an extra kwarg `tokens` of text.
-
-    :param text_ctx: number of text tokens to expect.
-    :param xf_width: width of the transformer.
-    :param xf_layers: depth of the transformer.
-    :param xf_heads: heads in the transformer.
-    :param xf_final_ln: use a LayerNorm after the output layer.
-    :param tokenizer: the text tokenizer for sampling/vocab size.
+    """A UNetModel used in GLIDE that conditions on text with an encoding
+    transformer. Expects an extra kwarg `tokens` of text.
+
+    Args:
+        text_ctx (int): Number of text tokens to expect.
+        xf_width (int): Width of the transformer.
+        xf_layers (int): Depth of the transformer.
+        xf_heads (int): Number of heads in the transformer.
+        xf_final_ln (bool): Whether to use a LayerNorm after the output layer.
+        tokenizer (callable, optional): Text tokenizer for sampling/vocab
+            size. Defaults to get_encoder().
+        cache_text_emb (bool, optional): Whether to cache text embeddings.
+            Defaults to False.
+        xf_ar (float, optional): Autoregressive weight for the transformer.
+            Defaults to 0.0.
+        xf_padding (bool, optional): Whether to use padding in the transformer.
+            Defaults to False.
+        share_unemb (bool, optional): Whether to share the token embedding
+            weights with the output unembedding. Defaults to False.
     """
 
     def __init__(
@@ -46,8 +57,6 @@ def __init__(
         else:
             super().__init__(*args, **kwargs, encoder_channels=xf_width)
 
-        # del self.label_embedding
-
         if self.xf_width:
             self.transformer = Transformer(
                 text_ctx,
@@ -77,18 +86,6 @@ def __init__(
         self.cache_text_emb = cache_text_emb
         self.cache = None
 
-    # def convert_to_fp16(self):
-    #     super().convert_to_fp16()
-    #     if self.xf_width:
-    #         self.transformer.apply(convert_module_to_f16)
-    #         self.transformer_proj.to(torch.float16)
-    #         self.token_embedding.to(torch.float16)
-    #         self.positional_embedding.to(torch.float16)
-    #     if self.xf_padding:
-    #         self.padding_embedding.to(torch.float16)
-    #     if self.xf_ar:
-    #         self.unemb.to(torch.float16)
-
     def get_text_emb(self, tokens, mask):
         assert tokens is not None
 
@@ -134,7 +131,6 @@ def forward(self, x, timesteps, tokens=None, mask=None):
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(x.device)
 
-        # TODO not sure
         if timesteps.shape[0] != x.shape[0]:
             timesteps = timesteps.repeat(x.shape[0])
         emb = self.time_embedding(timesteps)
@@ -155,3 +151,29 @@ def forward(self, x, timesteps, tokens=None, mask=None):
         h = h.type(x.dtype)
         h = self.out(h)
         return h
+
+
+@MODELS.register_module()
+class SuperResText2ImUNet(Text2ImUNet):
+    """A UNetModel that performs super-resolution.
+
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """ + + def __init__(self, *args, **kwargs): + if 'in_channels' in kwargs: + kwargs = dict(kwargs) + kwargs['in_channels'] = kwargs['in_channels'] * 2 + else: + args = list(args) + args[1] = args[1] * 2 + super().__init__(*args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate( + low_res, (new_height, new_width), + mode='bilinear', + align_corners=False) + x = torch.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py new file mode 100644 index 0000000000..fd561a7e01 --- /dev/null +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform +import unittest +from unittest import TestCase + +import torch + +from mmagic.models.diffusion_schedulers import EditDDIMScheduler +from mmagic.utils import register_all_modules +from projects.glide.models import Glide, SuperResText2ImUNet, Text2ImUNet + +register_all_modules() + + +class TestGLIDE(TestCase): + + def setUp(self): + # low resolution cfg + unet_cfg = dict( + image_size=64, + base_channels=192, + in_channels=3, + resblocks_per_downsample=3, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, + ) + diffusion_scheduler_cfg = dict( + variance_type='learned_range', beta_schedule='squaredcos_cap_v2') + + # unet + self.unet = Text2ImUNet(**unet_cfg) + # diffusion_scheduler + self.diffusion_scheduler = EditDDIMScheduler(**diffusion_scheduler_cfg) + + # high resolution cfg + unet_up_cfg = dict( + image_size=256, + base_channels=192, + in_channels=3, + output_cfg=dict(var='FIXED'), + resblocks_per_downsample=2, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, + ) + diffusion_scheduler_up_cfg = dict( + variance_type='learned_range', beta_schedule='linear') + + # unet up + self.unet_up = SuperResText2ImUNet(**unet_up_cfg) + self.diffusion_scheduler_up = EditDDIMScheduler( + **diffusion_scheduler_up_cfg) + + @unittest.skipIf( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') + def test_init(self): + self.GLIDE = Glide( + unet=self.unet, + diffusion_scheduler=self.diffusion_scheduler, + unet_up=self.unet_up, + diffusion_scheduler_up=self.diffusion_scheduler_up) + + @unittest.skipIf( + ('win' in platform.system().lower()) + or (not torch.cuda.is_available()), + reason='skip on windows and cpu due to limited RAM.') + def test_infer(self): + # test infer resolution + text_prompts = 'clouds surround the mountains and palaces,sunshine' + image = self.GLIDE.infer( + prompt=text_prompts, show_progress=True, + num_inference_steps=2)['samples'] + assert 
image.shape == (1, 3, 256, 256)
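
As a shape-level sanity check, the low-resolution conditioning performed by `SuperResText2ImUNet.forward` (and the reason its constructor doubles `in_channels`) can be sketched with plain tensors. The 64x64 and 256x256 sizes and 3-channel images follow the 64-256 config; the tensors themselves are random stand-ins:

```python
import torch
import torch.nn.functional as F

# The noisy high-res input x_t and the base-model sample are fused by
# bilinear upsampling followed by channel concatenation, so the SR UNet
# sees in_channels * 2 channels.
x = torch.randn(1, 3, 256, 256)      # noisy 256x256 input at step t
low_res = torch.randn(1, 3, 64, 64)  # output of the base 64x64 stage

upsampled = F.interpolate(
    low_res, (256, 256), mode='bilinear', align_corners=False)
unet_input = torch.cat([x, upsampled], dim=1)

assert unet_input.shape == (1, 6, 256, 256)  # in_channels * 2 == 6
```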