Commit

Merge 535016a into 87622b9
Taited authored Jun 1, 2023
2 parents 87622b9 + 535016a commit 4742e44
Showing 7 changed files with 405 additions and 71 deletions.
35 changes: 32 additions & 3 deletions projects/glide/configs/README.md
@@ -34,9 +34,10 @@ Diffusion models have recently been shown to generate high-quality synthetic images

**Laion**

| Method | Resolution | Config | Weights |
| ------ | ---------- | -------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
| Glide | 64x64 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmagic/glide/glide_laion-64x64-02afff47.pth) |
| Method | Resolution | Config | Weights |
| ------ | ---------------- | --------------------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
| Glide | 64x64 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64x64-02afff47.pth) |
| Glide | 64x64 -> 256x256 | [config](projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py) | [model](https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth) |

## Quick Start

@@ -66,6 +67,34 @@ with torch.no_grad():
show_progress=True)['samples']
```

You can synthesize images at 256x256 resolution:

```python
import torch
from torchvision.utils import save_image
from mmedit.apis import init_model
from mmengine.registry import init_default_scope
from projects.glide.models import *

init_default_scope('mmedit')

config = 'projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py'
ckpt = 'https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth'
model = init_model(config, ckpt).cuda().eval()
prompt = "an oil painting of a corgi"

with torch.no_grad():
    samples = model.infer(init_image=None,
                          prompt=prompt,
                          batch_size=16,
                          guidance_scale=3.,
                          num_inference_steps=100,
                          labels=None,
                          classifier_scale=0.0,
                          show_progress=True)['samples']
save_image(samples, "corgi.png", nrow=4, normalize=True, value_range=(-1, 1))
```
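
Here `guidance_scale` drives classifier-free guidance: inside `Glide.infer` (see the `projects/glide/models/glide.py` diff below), the UNet predicts noise for a conditional and an unconditional batch and extrapolates between the two. A minimal sketch of that update rule, with the function name and tensor names chosen only for illustration:

```python
import torch


def cfg_eps(cond_eps: torch.Tensor, uncond_eps: torch.Tensor,
            guidance_scale: float) -> torch.Tensor:
    # Extrapolate from the unconditional prediction toward the
    # conditional one; scales > 1 strengthen prompt adherence
    # at the cost of sample diversity.
    return uncond_eps + guidance_scale * (cond_eps - uncond_eps)
```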

## Citation

```bibtex
68 changes: 68 additions & 0 deletions projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py
@@ -0,0 +1,68 @@
unet_cfg = dict(
    type='Text2ImUNet',
    image_size=64,
    base_channels=192,
    in_channels=3,
    resblocks_per_downsample=3,
    attention_res=(32, 16, 8),
    norm_cfg=dict(type='GN32', num_groups=32),
    dropout=0.1,
    num_classes=0,
    use_fp16=False,
    resblock_updown=True,
    attention_cfg=dict(
        type='MultiHeadAttentionBlock',
        num_heads=1,
        num_head_channels=64,
        use_new_attention_order=False,
        encoder_channels=512),
    use_scale_shift_norm=True,
    text_ctx=128,
    xf_width=512,
    xf_layers=16,
    xf_heads=8,
    xf_final_ln=True,
    xf_padding=True,
)
unet_up_cfg = dict(
    type='SuperResText2ImUNet',
    image_size=256,
    base_channels=192,
    in_channels=3,
    output_cfg=dict(var='FIXED'),
    resblocks_per_downsample=2,
    attention_res=(32, 16, 8),
    norm_cfg=dict(type='GN32', num_groups=32),
    dropout=0.1,
    num_classes=0,
    use_fp16=False,
    resblock_updown=True,
    attention_cfg=dict(
        type='MultiHeadAttentionBlock',
        num_heads=1,
        num_head_channels=64,
        use_new_attention_order=False,
        encoder_channels=512),
    use_scale_shift_norm=True,
    text_ctx=128,
    xf_width=512,
    xf_layers=16,
    xf_heads=8,
    xf_final_ln=True,
    xf_padding=True,
)

model = dict(
    type='Glide',
    data_preprocessor=dict(type='DataPreprocessor', mean=[127.5], std=[127.5]),
    unet=unet_cfg,
    diffusion_scheduler=dict(
        type='EditDDIMScheduler',
        variance_type='learned_range',
        beta_schedule='squaredcos_cap_v2'),
    unet_up=unet_up_cfg,
    diffusion_scheduler_up=dict(
        type='EditDDIMScheduler',
        variance_type='learned_range',
        beta_schedule='linear'),
    use_fp16=False)
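
This config chains the 64x64 base `Text2ImUNet` to a `SuperResText2ImUNet` upsampler, giving each stage its own `EditDDIMScheduler` (cosine beta schedule for the base stage, linear for upsampling). A minimal sketch of exercising the cascade through the same high-level API as the README example above; the prompt and batch size are arbitrary and the snippet is illustrative rather than canonical:

```python
import torch
from mmedit.apis import init_model
from mmengine.registry import init_default_scope
from projects.glide.models import *  # noqa: F401,F403 (registers GLIDE modules)

init_default_scope('mmedit')

model = init_model(
    'projects/glide/configs/glide_ddim-classifier-free_laion-64-256.py',
    'https://download.openmmlab.com/mmediting/glide/glide_laion-64-256-02afff47.pth'
).cuda().eval()

with torch.no_grad():
    # Because unet_up is configured, infer() runs the 64x64 base stage
    # and then hands the result to infer_up() for 256x256 refinement.
    samples = model.infer(prompt='an oil painting of a corgi',
                          batch_size=4)['samples']
print(samples.shape)  # expected: torch.Size([4, 3, 256, 256])
```
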
projects/glide/configs/glide_ddim-classifier-free_laion-64x64.py
@@ -28,7 +28,7 @@
        xf_padding=True,
    ),
    diffusion_scheduler=dict(
        type='DDIMScheduler',
        type='EditDDIMScheduler',
        variance_type='learned_range',
        beta_schedule='squaredcos_cap_v2'),
    use_fp16=False)
4 changes: 2 additions & 2 deletions projects/glide/models/__init__.py
@@ -1,4 +1,4 @@
from .glide import Glide
from .text2im_unet import Text2ImUNet
from .text2im_unet import SuperResText2ImUNet, Text2ImUNet

__all__ = ['Text2ImUNet', 'Glide']
__all__ = ['Text2ImUNet', 'Glide', 'SuperResText2ImUNet']
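
With the updated exports, the upsampling UNet is importable directly from the project package, e.g.:

```python
from projects.glide.models import Glide, SuperResText2ImUNet, Text2ImUNet
```
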
190 changes: 149 additions & 41 deletions projects/glide/models/glide.py
Expand Up @@ -12,12 +12,10 @@
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from tqdm import tqdm

from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS, MODULES
from mmagic.registry import DIFFUSION_SCHEDULERS, MODELS
from mmagic.structures import DataSample
from mmagic.utils.typing import ForwardInputs, SampleList

# from .guider import ImageTextGuider

ModelType = Union[Dict, nn.Module]


@@ -32,38 +30,70 @@ def classifier_grad(classifier, x, t, y=None, classifier_scale=1.0):
return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale


@MODELS.register_module('GLIDE')
@MODELS.register_module()
class Glide(BaseModel):
"""Guided diffusion Model.
"""GLIDE: Guided language to image diffusion for generation and editing.
Refer to: https://github.com/openai/glide-text2im.
Args:
data_preprocessor (dict, optional): The pre-process config of
data_preprocessor (dict, optional): The pre-process configuration for
:class:`BaseDataPreprocessor`.
unet (ModelType): Config of denoising Unet.
diffusion_scheduler (ModelType): Config of diffusion_scheduler
unet (ModelType): Configuration for the denoising Unet.
diffusion_scheduler (ModelType): Configuration for the diffusion
scheduler.
use_fp16 (bool): Whether to use fp16 for unet model. Defaults to False.
classifier (ModelType): Config of classifier. Defaults to None.
pretrained_cfgs (dict): Path Config for pretrained weights. Usually
this is a dict contains module name and the corresponding ckpt
path.Defaults to None.
unet_up (ModelType, optional): Configuration for the upsampling
denoising UNet. Defaults to None.
diffusion_scheduler_up (ModelType, optional): Configuration for
the upsampling diffusion scheduler. Defaults to None.
use_fp16 (bool, optional): Whether to use fp16 for the unet model.
Defaults to False.
classifier (ModelType, optional): Configuration for the classifier.
Defaults to None.
classifier_scale (float): Classifier scale for classifier guidance.
Defaults to 1.0.
data_preprocessor (Optional[ModelType]): Configuration for the data
preprocessor.
pretrained_cfgs (dict, optional): Path configuration for pretrained
weights. Usually, this is a dict containing the module name and
the corresponding ckpt path. Defaults to None.
"""

def __init__(self,
data_preprocessor,
unet,
diffusion_scheduler,
use_fp16=False,
classifier=None,
classifier_scale=1.0,
pretrained_cfgs=None):
unet: ModelType,
diffusion_scheduler: ModelType,
unet_up: Optional[ModelType] = None,
diffusion_scheduler_up: Optional[ModelType] = None,
use_fp16: Optional[bool] = False,
classifier: Optional[dict] = None,
classifier_scale: float = 1.0,
data_preprocessor: Optional[ModelType] = dict(
type='DataPreprocessor'),
pretrained_cfgs: Optional[dict] = None):

super().__init__(data_preprocessor=data_preprocessor)
self.unet = MODULES.build(unet)
self.unet = unet if isinstance(unet, nn.Module) else MODELS.build(unet)
self.diffusion_scheduler = DIFFUSION_SCHEDULERS.build(
diffusion_scheduler)
diffusion_scheduler) if isinstance(diffusion_scheduler,
dict) else diffusion_scheduler

self.unet_up = None
self.diffusion_scheduler_up = None
if unet_up:
self.unet_up = unet_up if isinstance(
unet_up, nn.Module) else MODELS.build(unet_up)
if diffusion_scheduler_up:
self.diffusion_scheduler_up = DIFFUSION_SCHEDULERS.build(
diffusion_scheduler_up) if isinstance(
diffusion_scheduler_up,
dict) else diffusion_scheduler_up
else:
self.diffusion_scheduler_up = deepcopy(
self.diffusion_scheduler)

if classifier:
self.classifier = MODULES.build(classifier)
self.classifier = MODELS.build(classifier)
else:
self.classifier = None
self.classifier_scale = classifier_scale
@@ -101,26 +131,34 @@ def device(self):

@torch.no_grad()
def infer(self,
init_image=None,
prompt=None,
batch_size=1,
guidance_scale=3.,
num_inference_steps=50,
labels=None,
classifier_scale=0.0,
show_progress=False):
"""_summary_
init_image: Optional[torch.Tensor] = None,
prompt: str = None,
batch_size: Optional[int] = 1,
guidance_scale: float = 3.,
num_inference_steps: int = 50,
num_inference_steps_up: Optional[int] = 27,
labels: Optional[torch.Tensor] = None,
classifier_scale: float = 0.0,
show_progress: Optional[bool] = False):
"""Inference function for guided diffusion.
Args:
init_image (_type_, optional): _description_. Defaults to None.
batch_size (int, optional): _description_. Defaults to 1.
num_inference_steps (int, optional): _description_.
Defaults to 1000.
labels (_type_, optional): _description_. Defaults to None.
show_progress (bool, optional): _description_. Defaults to False.
init_image (torch.Tensor, optional): Starting noise for diffusion.
Defaults to None.
prompt (str): The prompt to guide the image generation.
batch_size (int, optional): Batch size for generation.
Defaults to 1.
num_inference_steps (int, optional): The number of denoising steps.
Defaults to 50.
num_inference_steps_up (int, optional): The number of upsampling
denoising steps. Defaults to 27.
labels (torch.Tensor, optional): Labels for the classifier.
Defaults to None.
show_progress (bool, optional): Whether to show the progress bar.
Defaults to False.
Returns:
_type_: _description_
torch.Tensor: Generated images.
"""
# Sample gaussian noise to begin loop
if init_image is None:
@@ -167,9 +205,6 @@ def infer(self,
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
noise_pred = torch.cat([eps, rest], dim=1)
# noise_pred_text, noise_pred_uncond = model_output.chunk(2)
# noise_pred = noise_pred_uncond + guidance_scale *
# (noise_pred_text - noise_pred_uncond)

# 2. compute previous image: x_t -> x_t-1
diffusion_scheduler_output = self.diffusion_scheduler.step(
@@ -191,8 +226,79 @@ def infer(self,
else:
image = diffusion_scheduler_output['prev_sample']

# abandon unconditional image
image = image[:image.shape[0] // 2]

if self.unet_up:
image = self.infer_up(
low_res_img=image,
batch_size=batch_size,
prompt=prompt,
num_inference_steps=num_inference_steps_up)

return {'samples': image}

@torch.no_grad()
def infer_up(self,
low_res_img: torch.Tensor,
batch_size: int = 1,
init_image: Optional[torch.Tensor] = None,
prompt: Optional[str] = None,
num_inference_steps: int = 27,
show_progress: bool = False):
"""Inference function for upsampling guided diffusion.
Args:
low_res_img (torch.Tensor): Low resolution image
(shape: [B, C, H, W]) for upsampling.
batch_size (int, optional): Batch size for generation.
Defaults to 1.
init_image (torch.Tensor, optional): Starting noise
(shape: [B, C, H, W]) for diffusion. Defaults to None.
prompt (str, optional): The text prompt to guide the image
generation. Defaults to None.
num_inference_steps (int, optional): The number of denoising
steps. Defaults to 27.
show_progress (bool, optional): Whether to show the progress bar.
Defaults to False.
Returns:
torch.Tensor: Generated upsampled images (shape: [B, C, H, W]).
"""
if init_image is None:
image = torch.randn(
(batch_size, self.unet_up.in_channels // 2,
self.unet_up.image_size, self.unet_up.image_size))
image = image.to(self.device)
else:
image = init_image

# set step values
if num_inference_steps > 0:
self.diffusion_scheduler_up.set_timesteps(num_inference_steps)
timesteps = self.diffusion_scheduler_up.timesteps

# text embedding
tokens = self.unet.tokenizer.encode(prompt)
tokens, mask = self.unet.tokenizer.padded_tokens_and_mask(tokens, 128)
# token ids must remain integers; only the mask below is boolean
tokens = torch.tensor(
    [tokens] * batch_size, device=self.device)
mask = torch.tensor(
[mask] * batch_size, dtype=torch.bool, device=self.device)

if show_progress and mmengine.dist.is_main_process():
timesteps = tqdm(timesteps)

for t in timesteps:
noise_pred = self.unet_up(
image, t, low_res=low_res_img, tokens=tokens, mask=mask)
# compute previous image: x_t -> x_t-1
diffusion_scheduler_output = self.diffusion_scheduler_up.step(
noise_pred, t, image)
image = diffusion_scheduler_output['prev_sample']

return image

def forward(self,
inputs: ForwardInputs,
data_samples: Optional[list] = None,
@@ -253,6 +359,7 @@ def forward(self,
batch_sample_list.append(gen_sample)
return batch_sample_list

@torch.no_grad()
def val_step(self, data: dict) -> SampleList:
"""Gets the generated image of given data.
@@ -271,6 +378,7 @@ def val_step(self, data: dict) -> SampleList:
outputs = self(**data)
return outputs

@torch.no_grad()
def test_step(self, data: dict) -> SampleList:
"""Gets the generated image of given data. Same as :meth:`val_step`.