From 5b026f6f5f5437a00d7ac3f9f29410f61591a71d Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 23 Feb 2023 22:19:25 +0800 Subject: [PATCH 01/14] fix import MODULE bug in Glide --- projects/glide/models/glide.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index c3d9669f55..0f80f22ef9 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -12,7 +12,7 @@ from mmengine.runner.checkpoint import _load_checkpoint_with_prefix from tqdm import tqdm -from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES +from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS from mmedit.structures import EditDataSample, PixelData from mmedit.utils.typing import ForwardInputs, SampleList @@ -59,11 +59,11 @@ def __init__(self, pretrained_cfgs=None): super().__init__(data_preprocessor=data_preprocessor) - self.unet = MODULES.build(unet) + self.unet = MODELS.build(unet) self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build( diffusion_scheduler) if classifier: - self.classifier = MODULES.build(classifier) + self.classifier = MODELS.build(classifier) else: self.classifier = None self.classifier_scale = classifier_scale From dbf96fb03cd1098287abdd1d05440f362656516c Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 23 Feb 2023 22:20:43 +0800 Subject: [PATCH 02/14] support glide unet super resolution --- projects/glide/models/__init__.py | 4 ++-- projects/glide/models/text2im_unet.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/projects/glide/models/__init__.py b/projects/glide/models/__init__.py index 6b40ebf542..59947b2cb8 100644 --- a/projects/glide/models/__init__.py +++ b/projects/glide/models/__init__.py @@ -1,4 +1,4 @@ from .glide import Glide -from .text2im_unet import Text2ImUNet +from .text2im_unet import SuperResText2ImUNet, Text2ImUNet -__all__ = ['Text2ImUNet', 'Glide'] +__all__ = ['Text2ImUNet', 'Glide', 'SuperResText2ImUNet'] diff --git a/projects/glide/models/text2im_unet.py b/projects/glide/models/text2im_unet.py index c14d65d527..ed4b63ec9d 100644 --- a/projects/glide/models/text2im_unet.py +++ b/projects/glide/models/text2im_unet.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F from mmedit.models import DenoisingUnet from mmedit.registry import MODELS @@ -155,3 +156,27 @@ def forward(self, x, timesteps, tokens=None, mask=None): h = h.type(x.dtype) h = self.out(h) return h + + +@MODELS.register_module() +class SuperResText2ImUNet(Text2ImUNet): + """A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, *args, **kwargs): + if 'in_channels' in kwargs: + kwargs = dict(kwargs) + kwargs['in_channels'] = kwargs['in_channels'] * 2 + else: + args = list(args) + args[1] = args[1] * 2 + super().__init__(*args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate( + low_res, (new_height, new_width), mode='bilinear') + x = torch.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) From 220abeae89c2c578d680bf6433a37e6363c3fc2b Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 23 Feb 2023 23:34:22 +0800 Subject: [PATCH 03/14] support glide to upsample image from 64 to 256 --- ...im-classifier-free_laion-64x64->256x256.py | 69 ++++++++++++++++ projects/glide/models/glide.py | 81 +++++++++++++++++-- 2 files changed, 145 insertions(+), 5 deletions(-) create mode 100644 projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py new file mode 100644 index 0000000000..eee8b54076 --- /dev/null +++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py @@ -0,0 +1,69 @@ +unet_cfg = dict( + type='Text2ImUNet', + image_size=64, + base_channels=192, + in_channels=3, + resblocks_per_downsample=3, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, +) +unet_up_cfg = dict( + type='SuperResText2ImUNet', + image_size=256, + base_channels=192, + in_channels=3, + output_cfg=dict(var='FIXED'), + resblocks_per_downsample=2, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, +) + +model = dict( + type='Glide', + data_preprocessor=dict( + type='EditDataPreprocessor', mean=[127.5], std=[127.5]), + unet=unet_cfg, + diffusion_scheduler=dict( + type='DDIMScheduler', + variance_type='learned_range', + beta_schedule='squaredcos_cap_v2'), + unet_up=unet_up_cfg, + diffusion_scheduler_up=dict( + type='DDIMScheduler', + variance_type='learned_range', + beta_schedule='linear'), + use_fp16=False) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index 0f80f22ef9..47196b1e09 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -16,8 +16,6 @@ from mmedit.structures import EditDataSample, PixelData from mmedit.utils.typing import ForwardInputs, SampleList -# from .guider import ImageTextGuider - ModelType = Union[Dict, nn.Module] @@ -53,6 +51,8 @@ def __init__(self, data_preprocessor, unet, diffusion_scheduler, + unet_up=None, + diffusion_scheduler_up=None, use_fp16=False, classifier=None, classifier_scale=1.0, @@ -62,6 +62,18 @@ def __init__(self, self.unet = MODELS.build(unet) self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build( diffusion_scheduler) + + self.unet_up = None + self.diffusion_scheduler_up = None + if unet_up: + self.unet_up = MODELS.build(unet_up) + if diffusion_scheduler_up: + self.diffusion_scheduler_up = DIFFUSION_SCHEDULERS.build( + diffusion_scheduler_up) + else: + self.diffusion_scheduler_up = deepcopy( + self.diffusion_scheduler) + if classifier: self.classifier = MODELS.build(classifier) else: @@ -167,9 +179,6 @@ def infer(self, half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) noise_pred = torch.cat([eps, rest], dim=1) - # noise_pred_text, noise_pred_uncond = model_output.chunk(2) - # noise_pred = noise_pred_uncond + guidance_scale * - # (noise_pred_text - noise_pred_uncond) # 2. compute previous image: x_t -> x_t-1 diffusion_scheduler_output = self.diffusion_scheduler.step( @@ -191,8 +200,70 @@ def infer(self, else: image = diffusion_scheduler_output['prev_sample'] + # abandon unconditional image + image = image[:image.shape[0] // 2] + + if self.unet_up: + image = self.infer_up( + low_res_img=image, batch_size=batch_size, prompt=prompt) + return {'samples': image} + @torch.no_grad() + def infer_up(self, + low_res_img, + batch_size=1, + init_image=None, + prompt=None, + num_inference_steps=27, + show_progress=False): + """_summary_ + + Args: + init_image (_type_, optional): _description_. Defaults to None. + batch_size (int, optional): _description_. Defaults to 1. + num_inference_steps (int, optional): _description_. + Defaults to 1000. + labels (_type_, optional): _description_. Defaults to None. + show_progress (bool, optional): _description_. Defaults to False. + + Returns: + _type_: _description_ + """ + if init_image is None: + image = torch.randn( + (batch_size, self.unet_up.in_channels // 2, + self.unet_up.image_size, self.unet_up.image_size)) + image = image.to(self.device) + else: + image = init_image + + # set step values + if num_inference_steps > 0: + self.diffusion_scheduler_up.set_timesteps(num_inference_steps) + timesteps = self.diffusion_scheduler_up.timesteps + + # text embedding + tokens = self.unet.tokenizer.encode(prompt) + tokens, mask = self.unet.tokenizer.padded_tokens_and_mask(tokens, 128) + tokens = torch.tensor( + [tokens] * batch_size, dtype=torch.bool, device=self.device) + mask = torch.tensor( + [mask] * batch_size, dtype=torch.bool, device=self.device) + + if show_progress and mmengine.dist.is_main_process(): + timesteps = tqdm(timesteps) + + for t in timesteps: + noise_pred = self.unet_up( + image, t, low_res=low_res_img, tokens=tokens, mask=mask) + # compute previous image: x_t -> x_t-1 + diffusion_scheduler_output = self.diffusion_scheduler_up.step( + noise_pred, t, image) + image = diffusion_scheduler_output['prev_sample'] + + return image + def forward(self, inputs: ForwardInputs, data_samples: Optional[list] = None, From 26826c3ab4df684dc6b98f89a57f29cb5bae453b Mon Sep 17 00:00:00 2001 From: Taited Date: Fri, 24 Feb 2023 13:12:45 +0800 Subject: [PATCH 04/14] refactor config name, fix F.interpolation warning --- projects/glide/configs/README.md | 35 +++++++++++++++++-- ...lide_ddim-classifier-free_laion-64-256.py} | 0 projects/glide/models/text2im_unet.py | 4 ++- 3 files changed, 35 insertions(+), 4 deletions(-) rename projects/glide/configs/{glide_ddim-classifier-free_laion-64x64->256x256.py => glide_ddim-classifier-free_laion-64-256.py} (100%) diff --git a/projects/glide/configs/README.md b/projects/glide/configs/README.md index 321fec7b9f..3db4d470fe 100644 --- a/projects/glide/configs/README.md +++ b/projects/glide/configs/README.md @@ -34,9 +34,10 @@ Diffusion models have recently been shown to generate high-quality synthetic ima **Laion** -| Method | Resolution | Config | Weights | -| ------ | ---------- | -------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- | -| Glide | 64x64 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth) | +| Method | Resolution | Config | Weights | +| ------ | ---------------- | --------------------------------------------------------------------------- | --------------------------------------------------------------------------------------- | +| Glide | 64x64 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth) | +| Glide | 64x64 -> 256x256 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth) | ## Quick Start @@ -66,6 +67,34 @@ with torch.no_grad(): show_progress=True)['samples'] ``` +You can synthesis images with 256x256 resolution: + +```python +import torch +from torchvision.utils import save_image +from mmedit.apis import init_model +from mmengine.registry import init_default_scope +from projects.glide.models import * + +init_default_scope('mmedit') + +config = 'projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py' +ckpt = 'https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth' +model = init_model(config, ckpt).cuda().eval() +prompt = "an oil painting of a corgi" + +with torch.no_grad(): + samples = model.infer(init_image=None, + prompt=prompt, + batch_size=16, + guidance_scale=3., + num_inference_steps=100, + labels=None, + classifier_scale=0.0, + show_progress=True)['samples'] +save_image(samples, "corgi.png", nrow=4, normalize=True, value_range=(-1, 1)) +``` + ## Citation ```bibtex diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py similarity index 100% rename from projects/glide/configs/glide_ddim-classifier-free_laion-64x64->256x256.py rename to projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py diff --git a/projects/glide/models/text2im_unet.py b/projects/glide/models/text2im_unet.py index ed4b63ec9d..6535b017b0 100644 --- a/projects/glide/models/text2im_unet.py +++ b/projects/glide/models/text2im_unet.py @@ -177,6 +177,8 @@ def __init__(self, *args, **kwargs): def forward(self, x, timesteps, low_res=None, **kwargs): _, _, new_height, new_width = x.shape upsampled = F.interpolate( - low_res, (new_height, new_width), mode='bilinear') + low_res, (new_height, new_width), + mode='bilinear', + align_corners=False) x = torch.cat([x, upsampled], dim=1) return super().forward(x, timesteps, **kwargs) From 7a0482c4f39168fc56586079b0f0758cc679ee11 Mon Sep 17 00:00:00 2001 From: Taited Date: Wed, 12 Apr 2023 11:57:12 +0800 Subject: [PATCH 05/14] add more comments --- projects/glide/models/glide.py | 125 +++++++++++++++----------- projects/glide/models/text2im_unet.py | 43 ++++----- 2 files changed, 93 insertions(+), 75 deletions(-) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index 65d8f25a90..ecc057e66f 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -12,10 +12,8 @@ from mmengine.runner.checkpoint import _load_checkpoint_with_prefix from tqdm import tqdm - -from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES +from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS from mmedit.structures import EditDataSample - from mmedit.utils.typing import ForwardInputs, SampleList ModelType = Union[Dict, nn.Module] @@ -32,33 +30,47 @@ def classifier_grad(classifier, x, t, y=None, classifier_scale=1.0): return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale +@MODELS.register_module('GLIDE') @MODELS.register_module() class Glide(BaseModel): - """Guided diffusion Model. + """GLIDE: Guided language to image diffusion for generation and editing. + Refer to: https://github.com/openai/glide-text2im. + Args: - data_preprocessor (dict, optional): The pre-process config of + data_preprocessor (dict, optional): The pre-process configuration for :class:`BaseDataPreprocessor`. - unet (ModelType): Config of denoising Unet. - diffusion_scheduler (ModelType): Config of diffusion_scheduler + unet (ModelType): Configuration for the denoising Unet. + diffusion_scheduler (ModelType): Configuration for the diffusion scheduler. - use_fp16 (bool): Whether to use fp16 for unet model. Defaults to False. - classifier (ModelType): Config of classifier. Defaults to None. - pretrained_cfgs (dict): Path Config for pretrained weights. Usually - this is a dict contains module name and the corresponding ckpt - path.Defaults to None. + unet_up (ModelType, optional): Configuration for the upsampling + denoising UNet. Defaults to None. + diffusion_scheduler_up (ModelType, optional): Configuration for + the upsampling diffusion scheduler. Defaults to None. + use_fp16 (bool, optional): Whether to use fp16 for the unet model. + Defaults to False. + classifier (ModelType, optional): Configuration for the classifier. + Defaults to None. + classifier_scale (float): Classifier scale for classifier guidance. + Defaults to 1.0. + data_preprocessor (Optional[ModelType]): Configuration for the data + preprocessor. + pretrained_cfgs (dict, optional): Path configuration for pretrained + weights. Usually, this is a dict containing the module name and + the corresponding ckpt path. Defaults to None. """ def __init__(self, - data_preprocessor, - unet, - diffusion_scheduler, - unet_up=None, - diffusion_scheduler_up=None, - use_fp16=False, - classifier=None, - classifier_scale=1.0, - pretrained_cfgs=None): + unet: ModelType, + diffusion_scheduler: ModelType, + unet_up: Optional[ModelType] = None, + diffusion_scheduler_up: Optional[ModelType] = None, + use_fp16: Optional[bool] = False, + classifier: Optional[dict] = None, + classifier_scale: float = 1.0, + data_preprocessor: Optional[ModelType] = dict( + type='EditDataPreprocessor'), + pretrained_cfgs: Optional[dict] = None): super().__init__(data_preprocessor=data_preprocessor) self.unet = MODELS.build(unet) @@ -115,26 +127,31 @@ def device(self): @torch.no_grad() def infer(self, - init_image=None, - prompt=None, - batch_size=1, - guidance_scale=3., - num_inference_steps=50, - labels=None, - classifier_scale=0.0, - show_progress=False): - """_summary_ + init_image: Optional[torch.Tensor] = None, + prompt: str = None, + batch_size: Optional[int] = 1, + guidance_scale: float = 3., + num_inference_steps: int = 50, + labels: Optional[torch.Tensor] = None, + classifier_scale: float = 0.0, + show_progress: Optional[bool] = False): + """Inference function for guided diffusion. Args: - init_image (_type_, optional): _description_. Defaults to None. - batch_size (int, optional): _description_. Defaults to 1. - num_inference_steps (int, optional): _description_. - Defaults to 1000. - labels (_type_, optional): _description_. Defaults to None. - show_progress (bool, optional): _description_. Defaults to False. + init_image (torch.Tensor, optional): Starting noise for diffusion. + Defaults to None. + prompt (str): The prompt to guide the image generation. + batch_size (int, optional): Batch size for generation. + Defaults to 1. + num_inference_steps (int, optional): The number of denoising steps. + Defaults to 50. + labels (torch.Tensor, optional): Labels for the classifier. + Defaults to None. + show_progress (bool, optional): Whether to show the progress bar. + Defaults to False. Returns: - _type_: _description_ + torch.Tensor: Generated images. """ # Sample gaussian noise to begin loop if init_image is None: @@ -213,24 +230,30 @@ def infer(self, @torch.no_grad() def infer_up(self, - low_res_img, - batch_size=1, - init_image=None, - prompt=None, - num_inference_steps=27, - show_progress=False): - """_summary_ + low_res_img: torch.Tensor, + batch_size: int = 1, + init_image: Optional[torch.Tensor] = None, + prompt: Optional[str] = None, + num_inference_steps: int = 27, + show_progress: bool = False): + """Inference function for upsampling guided diffusion. Args: - init_image (_type_, optional): _description_. Defaults to None. - batch_size (int, optional): _description_. Defaults to 1. - num_inference_steps (int, optional): _description_. - Defaults to 1000. - labels (_type_, optional): _description_. Defaults to None. - show_progress (bool, optional): _description_. Defaults to False. + low_res_img (torch.Tensor): Low resolution image + (shape: [B, C, H, W]) for upsampling. + batch_size (int, optional): Batch size for generation. + Defaults to 1. + init_image (torch.Tensor, optional): Starting noise + (shape: [B, C, H, W]) for diffusion. Defaults to None. + prompt (str, optional): The text prompt to guide the image + generation. Defaults to None. + num_inference_steps (int, optional): The number of denoising + steps. Defaults to 27. + show_progress (bool, optional): Whether to show the progress bar. + Defaults to False. Returns: - _type_: _description_ + torch.Tensor: Generated upsampled images (shape: [B, C, H, W]). """ if init_image is None: image = torch.randn( diff --git a/projects/glide/models/text2im_unet.py b/projects/glide/models/text2im_unet.py index 6535b017b0..cd77936900 100644 --- a/projects/glide/models/text2im_unet.py +++ b/projects/glide/models/text2im_unet.py @@ -10,15 +10,25 @@ @MODELS.register_module() class Text2ImUNet(DenoisingUnet): - """A UNetModel that conditions on text with an encoding transformer. - Expects an extra kwarg `tokens` of text. - - :param text_ctx: number of text tokens to expect. - :param xf_width: width of the transformer. - :param xf_layers: depth of the transformer. - :param xf_heads: heads in the transformer. - :param xf_final_ln: use a LayerNorm after the output layer. - :param tokenizer: the text tokenizer for sampling/vocab size. + """A UNetModel used in GLIDE that conditions on text with an encoding + transformer. Expects an extra kwarg `tokens` of text. + + Args: + text_ctx (int): Number of text tokens to expect. + xf_width (int): Width of the transformer. + xf_layers (int): Depth of the transformer. + xf_heads (int): Number of heads in the transformer. + xf_final_ln (bool): Whether to use a LayerNorm after the output layer. + tokenizer (callable, optional): Text tokenizer for sampling/vocab + size. Defaults to get_encoder(). + cache_text_emb (bool, optional): Whether to cache text embeddings. + Defaults to False. + xf_ar (float, optional): Autoregressive weight for the transformer. + Defaults to 0.0. + xf_padding (bool, optional): Whether to use padding in the transformer. + Defaults to False. + share_unemb (bool, optional): Whether to share UNet embeddings. + Defaults to False. """ def __init__( @@ -47,8 +57,6 @@ def __init__( else: super().__init__(*args, **kwargs, encoder_channels=xf_width) - # del self.label_embedding - if self.xf_width: self.transformer = Transformer( text_ctx, @@ -78,18 +86,6 @@ def __init__( self.cache_text_emb = cache_text_emb self.cache = None - # def convert_to_fp16(self): - # super().convert_to_fp16() - # if self.xf_width: - # self.transformer.apply(convert_module_to_f16) - # self.transformer_proj.to(torch.float16) - # self.token_embedding.to(torch.float16) - # self.positional_embedding.to(torch.float16) - # if self.xf_padding: - # self.padding_embedding.to(torch.float16) - # if self.xf_ar: - # self.unemb.to(torch.float16) - def get_text_emb(self, tokens, mask): assert tokens is not None @@ -135,7 +131,6 @@ def forward(self, x, timesteps, tokens=None, mask=None): elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(x.device) - # TODO not sure if timesteps.shape[0] != x.shape[0]: timesteps = timesteps.repeat(x.shape[0]) emb = self.time_embedding(timesteps) From 916603d57bf21f64915729b5fdfcba0a5d837a82 Mon Sep 17 00:00:00 2001 From: Taited Date: Wed, 12 Apr 2023 17:11:05 +0800 Subject: [PATCH 06/14] change DDIMScheduler to EditDDIMScheduler, add unit test --- ...glide_ddim-classifier-free_laion-64-256.py | 4 +- .../glide_ddim-classifier-free_laion-64x64.py | 2 +- projects/glide/models/glide.py | 8 +- .../test_editors/test_glide/test_glide.py | 104 ++++++++++++++++++ 4 files changed, 114 insertions(+), 4 deletions(-) create mode 100644 tests/test_models/test_editors/test_glide/test_glide.py diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py index eee8b54076..6a357b9df6 100644 --- a/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py +++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py @@ -58,12 +58,12 @@ type='EditDataPreprocessor', mean=[127.5], std=[127.5]), unet=unet_cfg, diffusion_scheduler=dict( - type='DDIMScheduler', + type='EditDDIMScheduler', variance_type='learned_range', beta_schedule='squaredcos_cap_v2'), unet_up=unet_up_cfg, diffusion_scheduler_up=dict( - type='DDIMScheduler', + type='EditDDIMScheduler', variance_type='learned_range', beta_schedule='linear'), use_fp16=False) diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py index 5344540c3e..d4ea16fa25 100644 --- a/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py +++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py @@ -29,7 +29,7 @@ xf_padding=True, ), diffusion_scheduler=dict( - type='DDIMScheduler', + type='EditDDIMScheduler', variance_type='learned_range', beta_schedule='squaredcos_cap_v2'), use_fp16=False) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index ecc057e66f..34edb87151 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -132,6 +132,7 @@ def infer(self, batch_size: Optional[int] = 1, guidance_scale: float = 3., num_inference_steps: int = 50, + num_inference_steps_up: Optional[int] = 27, labels: Optional[torch.Tensor] = None, classifier_scale: float = 0.0, show_progress: Optional[bool] = False): @@ -145,6 +146,8 @@ def infer(self, Defaults to 1. num_inference_steps (int, optional): The number of denoising steps. Defaults to 50. + num_inference_steps_up (int, optional): The number of upsampling + denoising steps. Defaults to 27. labels (torch.Tensor, optional): Labels for the classifier. Defaults to None. show_progress (bool, optional): Whether to show the progress bar. @@ -224,7 +227,10 @@ def infer(self, if self.unet_up: image = self.infer_up( - low_res_img=image, batch_size=batch_size, prompt=prompt) + low_res_img=image, + batch_size=batch_size, + prompt=prompt, + num_inference_steps=num_inference_steps_up) return {'samples': image} diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py new file mode 100644 index 0000000000..8f40545e0f --- /dev/null +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch +from mmengine import MODELS, Config + +from mmedit.utils import register_all_modules + +register_all_modules() + +unet = dict( + type='Text2ImUNet', + image_size=64, + base_channels=192, + in_channels=3, + resblocks_per_downsample=3, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, +) + +diffusion_scheduler = dict( + type='EditDDIMScheduler', + variance_type='learned_range', + beta_schedule='squaredcos_cap_v2') + +unet_up = dict( + type='SuperResText2ImUNet', + image_size=256, + base_channels=192, + in_channels=3, + output_cfg=dict(var='FIXED'), + resblocks_per_downsample=2, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, +) + +diffusion_scheduler_up = dict( + type='EditDDIMScheduler', + variance_type='learned_range', + beta_schedule='linear') + +model = dict( + type='Glide', + data_preprocessor=dict( + type='EditDataPreprocessor', mean=[127.5], std=[127.5]), + unet=unet, + diffusion_scheduler=diffusion_scheduler, + unet_up=unet_up, + diffusion_scheduler_up=diffusion_scheduler_up, + use_fp16=False) + + +def test_glide(): + glide = MODELS.build(Config(model)) + prompt = 'an oil painting of a corgi' + + with pytest.raises(Exception): + glide.infer( + prompt=prompt, + batch_size=1, + num_inference_steps=1, + num_inference_steps_up=1) + + result = glide.infer( + init_image=torch.randn(1, 3, 64, 64), + prompt=prompt, + batch_size=1, + guidance_scale=3.0, + num_inference_steps=1, + num_inference_steps_up=1) + assert result['samples'].shape == (1, 3, 256, 256) From e350c882da64c9c98fbad3e037f88dfd98451040 Mon Sep 17 00:00:00 2001 From: Taited Date: Wed, 12 Apr 2023 21:22:04 +0800 Subject: [PATCH 07/14] modified unit test --- projects/glide/models/glide.py | 12 +- .../test_editors/test_glide/test_glide.py | 201 ++++++++++-------- 2 files changed, 120 insertions(+), 93 deletions(-) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index 34edb87151..11e05c1279 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -73,17 +73,21 @@ def __init__(self, pretrained_cfgs: Optional[dict] = None): super().__init__(data_preprocessor=data_preprocessor) - self.unet = MODELS.build(unet) + self.unet = unet if isinstance(unet, nn.Module) else MODELS.build(unet) self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build( - diffusion_scheduler) + diffusion_scheduler) if isinstance(diffusion_scheduler, + dict) else diffusion_scheduler self.unet_up = None self.diffusion_scheduler_up = None if unet_up: - self.unet_up = MODELS.build(unet_up) + self.unet_up = unet_up if isinstance( + unet_up, nn.Module) else MODELS.build(unet_up) if diffusion_scheduler_up: self.diffusion_scheduler_up = DIFFUSION_SCHEDULERS.build( - diffusion_scheduler_up) + diffusion_scheduler_up) if isinstance( + diffusion_scheduler_up, + dict) else diffusion_scheduler_up else: self.diffusion_scheduler_up = deepcopy( self.diffusion_scheduler) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index 8f40545e0f..21e5420207 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -1,104 +1,127 @@ # Copyright (c) OpenMMLab. All rights reserved. -import pytest +import unittest +from copy import deepcopy +from unittest import TestCase + import torch -from mmengine import MODELS, Config +from mmedit.models.diffusion_schedulers import EditDDIMScheduler from mmedit.utils import register_all_modules +from projects.glide.models import Glide, SuperResText2ImUNet, Text2ImUNet register_all_modules() -unet = dict( - type='Text2ImUNet', - image_size=64, - base_channels=192, - in_channels=3, - resblocks_per_downsample=3, - attention_res=(32, 16, 8), - norm_cfg=dict(type='GN32', num_groups=32), - dropout=0.1, - num_classes=0, - use_fp16=False, - resblock_updown=True, - attention_cfg=dict( - type='MultiHeadAttentionBlock', - num_heads=1, - num_head_channels=64, - use_new_attention_order=False, - encoder_channels=512), - use_scale_shift_norm=True, - text_ctx=128, - xf_width=512, - xf_layers=16, - xf_heads=8, - xf_final_ln=True, - xf_padding=True, -) -diffusion_scheduler = dict( - type='EditDDIMScheduler', - variance_type='learned_range', - beta_schedule='squaredcos_cap_v2') +class TestDiscoDiffusion(TestCase): + + def setUp(self): + # low resolution cfg + unet_cfg = dict( + image_size=64, + base_channels=192, + in_channels=3, + resblocks_per_downsample=3, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, + ) + diffusion_scheduler_cfg = dict( + variance_type='learned_range', beta_schedule='squaredcos_cap_v2') + + # unet + self.unet = Text2ImUNet(**unet_cfg) + # diffusion_scheduler + self.diffusion_scheduler = EditDDIMScheduler(**diffusion_scheduler_cfg) -unet_up = dict( - type='SuperResText2ImUNet', - image_size=256, - base_channels=192, - in_channels=3, - output_cfg=dict(var='FIXED'), - resblocks_per_downsample=2, - attention_res=(32, 16, 8), - norm_cfg=dict(type='GN32', num_groups=32), - dropout=0.1, - num_classes=0, - use_fp16=False, - resblock_updown=True, - attention_cfg=dict( - type='MultiHeadAttentionBlock', - num_heads=1, - num_head_channels=64, - use_new_attention_order=False, - encoder_channels=512), - use_scale_shift_norm=True, - text_ctx=128, - xf_width=512, - xf_layers=16, - xf_heads=8, - xf_final_ln=True, - xf_padding=True, -) + # high resolution cfg + unet_up_cfg = dict( + image_size=256, + base_channels=192, + in_channels=3, + output_cfg=dict(var='FIXED'), + resblocks_per_downsample=2, + attention_res=(32, 16, 8), + norm_cfg=dict(type='GN32', num_groups=32), + dropout=0.1, + num_classes=0, + use_fp16=False, + resblock_updown=True, + attention_cfg=dict( + type='MultiHeadAttentionBlock', + num_heads=1, + num_head_channels=64, + use_new_attention_order=False, + encoder_channels=512), + use_scale_shift_norm=True, + text_ctx=128, + xf_width=512, + xf_layers=16, + xf_heads=8, + xf_final_ln=True, + xf_padding=True, + ) + diffusion_scheduler_up_cfg = dict( + variance_type='learned_range', beta_schedule='linear') -diffusion_scheduler_up = dict( - type='EditDDIMScheduler', - variance_type='learned_range', - beta_schedule='linear') + # unet up + self.unet_up = SuperResText2ImUNet(**unet_up_cfg) + self.diffusion_scheduler_up = EditDDIMScheduler( + **diffusion_scheduler_up_cfg) -model = dict( - type='Glide', - data_preprocessor=dict( - type='EditDataPreprocessor', mean=[127.5], std=[127.5]), - unet=unet, - diffusion_scheduler=diffusion_scheduler, - unet_up=unet_up, - diffusion_scheduler_up=diffusion_scheduler_up, - use_fp16=False) + def test_init(self): + # low resolution + unet = deepcopy(self.unet) + diffusion_scheduler = deepcopy(self.diffusion_scheduler) + self.GLIDE_low = Glide( + unet=unet, diffusion_scheduler=diffusion_scheduler) + # high resolution + unet_up = deepcopy(self.unet_up) + diffusion_scheduler_up = deepcopy(self.diffusion_scheduler_up) + self.GLIDE_high = Glide( + unet=unet, + diffusion_scheduler=diffusion_scheduler, + unet_up=unet_up, + diffusion_scheduler_up=diffusion_scheduler_up) -def test_glide(): - glide = MODELS.build(Config(model)) - prompt = 'an oil painting of a corgi' + @unittest.skipIf(not torch.cuda.is_available(), reason='requires cuda') + def test_infer(self): + # test low resolution + unet = deepcopy(self.unet) + diffusion_scheduler = deepcopy(self.diffusion_scheduler) + self.GLIDE = Glide(unet=unet, diffusion_scheduler=diffusion_scheduler) + self.GLIDE.cuda().eval() - with pytest.raises(Exception): - glide.infer( - prompt=prompt, - batch_size=1, - num_inference_steps=1, - num_inference_steps_up=1) + # test infer in low resolution + text_prompts = 'clouds surround the mountains and palaces,sunshine' + image = self.GLIDE.infer( + prompt=text_prompts, show_progress=True, + num_inference_steps=2)['samples'] + assert image.shape == (1, 3, 64, 64) - result = glide.infer( - init_image=torch.randn(1, 3, 64, 64), - prompt=prompt, - batch_size=1, - guidance_scale=3.0, - num_inference_steps=1, - num_inference_steps_up=1) - assert result['samples'].shape == (1, 3, 256, 256) + # test high resolution + unet_up = deepcopy(self.unet_up) + diffusion_scheduler_up = deepcopy(self.diffusion_scheduler_up) + self.GLIDE_high = Glide(unet, diffusion_scheduler, unet_up, + diffusion_scheduler_up) + self.GLIDE_high.cuda().eval() + image = self.GLIDE_high.infer( + prompt=text_prompts, show_progress=True, + num_inference_steps=2)['samples'] + assert image.shape == (1, 3, 256, 256) From ff49c9d4253516e7e9ff9e32b82e86dd57c8b6f1 Mon Sep 17 00:00:00 2001 From: Taited Date: Wed, 12 Apr 2023 21:57:40 +0800 Subject: [PATCH 08/14] skip windows test for limited RAM --- tests/test_models/test_editors/test_glide/test_glide.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index 21e5420207..c6c62addc3 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform import unittest from copy import deepcopy from unittest import TestCase @@ -12,7 +13,7 @@ register_all_modules() -class TestDiscoDiffusion(TestCase): +class TestGLIDE(TestCase): def setUp(self): # low resolution cfg @@ -101,6 +102,9 @@ def test_init(self): diffusion_scheduler_up=diffusion_scheduler_up) @unittest.skipIf(not torch.cuda.is_available(), reason='requires cuda') + @unittest.skipIf( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') def test_infer(self): # test low resolution unet = deepcopy(self.unet) From 793cad8df7ca449cb493eb97a8d5b86e95145e72 Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 13 Apr 2023 14:26:45 +0800 Subject: [PATCH 09/14] skip test_init in windows for RAM limitation --- tests/test_models/test_editors/test_glide/test_glide.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index c6c62addc3..04dcda14d5 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -85,6 +85,9 @@ def setUp(self): self.diffusion_scheduler_up = EditDDIMScheduler( **diffusion_scheduler_up_cfg) + @unittest.skipIf( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') def test_init(self): # low resolution unet = deepcopy(self.unet) From 86bb99ed34ae37eae11240d7c97598c3c8a2dd4d Mon Sep 17 00:00:00 2001 From: Taited Date: Sun, 23 Apr 2023 22:54:10 +0800 Subject: [PATCH 10/14] skip cpu and windows test --- tests/test_models/test_editors/test_glide/test_glide.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index 04dcda14d5..fc1b91ddb7 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -104,10 +104,10 @@ def test_init(self): unet_up=unet_up, diffusion_scheduler_up=diffusion_scheduler_up) - @unittest.skipIf(not torch.cuda.is_available(), reason='requires cuda') @unittest.skipIf( - 'win' in platform.system().lower(), - reason='skip on windows due to limited RAM.') + ('win' in platform.system().lower()) + or (not torch.cuda.is_available()), + reason='skip on windows and cpu due to limited RAM.') def test_infer(self): # test low resolution unet = deepcopy(self.unet) From 5ca574a0aba2d37c298f88a470f290f15a94de81 Mon Sep 17 00:00:00 2001 From: Taited Date: Wed, 26 Apr 2023 11:07:50 +0800 Subject: [PATCH 11/14] optimize unit test when running in low RAM --- .../test_editors/test_glide/test_glide.py | 39 +++---------------- 1 file changed, 6 insertions(+), 33 deletions(-) diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index fc1b91ddb7..66aa69833c 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import platform import unittest -from copy import deepcopy from unittest import TestCase import torch @@ -89,46 +88,20 @@ def setUp(self): 'win' in platform.system().lower(), reason='skip on windows due to limited RAM.') def test_init(self): - # low resolution - unet = deepcopy(self.unet) - diffusion_scheduler = deepcopy(self.diffusion_scheduler) - self.GLIDE_low = Glide( - unet=unet, diffusion_scheduler=diffusion_scheduler) - - # high resolution - unet_up = deepcopy(self.unet_up) - diffusion_scheduler_up = deepcopy(self.diffusion_scheduler_up) - self.GLIDE_high = Glide( - unet=unet, - diffusion_scheduler=diffusion_scheduler, - unet_up=unet_up, - diffusion_scheduler_up=diffusion_scheduler_up) + self.GLIDE = Glide( + unet=self.unet, + diffusion_scheduler=self.diffusion_scheduler, + unet_up=self.unet_up, + diffusion_scheduler_up=self.diffusion_scheduler_up) @unittest.skipIf( ('win' in platform.system().lower()) or (not torch.cuda.is_available()), reason='skip on windows and cpu due to limited RAM.') def test_infer(self): - # test low resolution - unet = deepcopy(self.unet) - diffusion_scheduler = deepcopy(self.diffusion_scheduler) - self.GLIDE = Glide(unet=unet, diffusion_scheduler=diffusion_scheduler) - self.GLIDE.cuda().eval() - - # test infer in low resolution + # test infer resolution text_prompts = 'clouds surround the mountains and palaces,sunshine' image = self.GLIDE.infer( prompt=text_prompts, show_progress=True, num_inference_steps=2)['samples'] - assert image.shape == (1, 3, 64, 64) - - # test high resolution - unet_up = deepcopy(self.unet_up) - diffusion_scheduler_up = deepcopy(self.diffusion_scheduler_up) - self.GLIDE_high = Glide(unet, diffusion_scheduler, unet_up, - diffusion_scheduler_up) - self.GLIDE_high.cuda().eval() - image = self.GLIDE_high.infer( - prompt=text_prompts, show_progress=True, - num_inference_steps=2)['samples'] assert image.shape == (1, 3, 256, 256) From f6693d238def9e78a7d310146d52905b594b5b91 Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 1 Jun 2023 13:24:11 +0800 Subject: [PATCH 12/14] adapt to mmagic --- projects/glide/models/glide.py | 8 +++++--- tests/test_models/test_editors/test_glide/test_glide.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index 7201c59382..fe7c8c1e3c 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -12,9 +12,9 @@ from mmengine.runner.checkpoint import _load_checkpoint_with_prefix from tqdm import tqdm -from mmedit.registry import DIFFUSION_SCHEDULERS, MODELS -from mmedit.structures import EditDataSample -from mmedit.utils.typing import ForwardInputs, SampleList +from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS +from mmagic.structures import DataSample +from mmagic.utils.typing import ForwardInputs, SampleList ModelType = Union[Dict, nn.Module] @@ -359,6 +359,7 @@ def forward(self, batch_sample_list.append(gen_sample) return batch_sample_list + @torch.no_grad() def val_step(self, data: dict) -> SampleList: """Gets the generated image of given data. @@ -377,6 +378,7 @@ def val_step(self, data: dict) -> SampleList: outputs = self(**data) return outputs + @torch.no_grad() def test_step(self, data: dict) -> SampleList: """Gets the generated image of given data. Same as :meth:`val_step`. diff --git a/tests/test_models/test_editors/test_glide/test_glide.py b/tests/test_models/test_editors/test_glide/test_glide.py index 66aa69833c..fd561a7e01 100644 --- a/tests/test_models/test_editors/test_glide/test_glide.py +++ b/tests/test_models/test_editors/test_glide/test_glide.py @@ -5,8 +5,8 @@ import torch -from mmedit.models.diffusion_schedulers import EditDDIMScheduler -from mmedit.utils import register_all_modules +from mmagic.models.diffusion_schedulers import EditDDIMScheduler +from mmagic.utils import register_all_modules from projects.glide.models import Glide, SuperResText2ImUNet, Text2ImUNet register_all_modules() From 1686e662f92e82a362b8eabff4eb72268c76a876 Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 1 Jun 2023 13:36:11 +0800 Subject: [PATCH 13/14] update mdformat --- projects/glide/configs/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/glide/configs/README.md b/projects/glide/configs/README.md index d128286514..0d7e4e3ca6 100644 --- a/projects/glide/configs/README.md +++ b/projects/glide/configs/README.md @@ -39,7 +39,6 @@ Diffusion models have recently been shown to generate high-quality synthetic ima | Glide | 64x64 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth) | | Glide | 64x64 -> 256x256 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth) | - ## Quick Start You can run glide as follows: From 535016aaf6eb8b0a001a9d59eae947e42b1d8f8e Mon Sep 17 00:00:00 2001 From: Taited Date: Thu, 1 Jun 2023 13:48:34 +0800 Subject: [PATCH 14/14] rename EditDataPreprocessor to DataPreprocessor --- .../glide/configs/glide_ddim-classifier-free_laion-64-256.py | 3 +-- projects/glide/models/glide.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py index 6a357b9df6..23774d0741 100644 --- a/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py +++ b/projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py @@ -54,8 +54,7 @@ model = dict( type='Glide', - data_preprocessor=dict( - type='EditDataPreprocessor', mean=[127.5], std=[127.5]), + data_preprocessor=dict(type='DataPreprocessor', mean=[127.5], std=[127.5]), unet=unet_cfg, diffusion_scheduler=dict( type='EditDDIMScheduler', diff --git a/projects/glide/models/glide.py b/projects/glide/models/glide.py index fe7c8c1e3c..dda6ea2989 100644 --- a/projects/glide/models/glide.py +++ b/projects/glide/models/glide.py @@ -69,7 +69,7 @@ def __init__(self, classifier: Optional[dict] = None, classifier_scale: float = 1.0, data_preprocessor: Optional[ModelType] = dict( - type='EditDataPreprocessor'), + type='DataPreprocessor'), pretrained_cfgs: Optional[dict] = None): super().__init__(data_preprocessor=data_preprocessor)