[feature] add t2i infer. #1504

Merged
merged 2 commits into from
Dec 5, 2022
24 changes: 12 additions & 12 deletions configs/disco_diffusion/README.md
@@ -24,18 +24,18 @@ Created by Somnai, augmented by Gandamu, and building on the work of RiversHaveW

We have converted several `unet` weights and provide the related configs. For usage of a different `unet`, please refer to the tutorial.

| Diffusion Model | Config | Weights |
| ---------------------------------------- | --------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- |
| 512x512_diffusion_uncond_finetune_008100 | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-512x512.py) | [weights](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u_finetuned_imagenet-512x512-ab471d70.pth) |
| 256x256_diffusion_uncond | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-256x256.py) | [weights](<>) |
| portrait_generator_v001 | [config](configs/disco/disco-diffusion_portrait_generator_v001.py) | [weights](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u-cvt-rgb_portrait-v001-f4a3f3bc.pth) |
| pixelartdiffusion_expanded | Coming soon! | |
| pixel_art_diffusion_hard_256 | Coming soon! | |
| pixel_art_diffusion_soft_256 | Coming soon! | |
| pixelartdiffusion4k | Coming soon! | |
| watercolordiffusion_2 | Coming soon! | |
| watercolordiffusion | Coming soon! | |
| PulpSciFiDiffusion | Coming soon! | |
| Diffusion Model | Config | Download |
| :--------------------------------------: | :-------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------: |
| 512x512_diffusion_uncond_finetune_008100 | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-512x512.py) | [model](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u_finetuned_imagenet-512x512-ab471d70.pth) |
| 256x256_diffusion_uncond | [config](configs/disco/disco-diffusion_adm-u-finetuned_imagenet-256x256.py) | [model](<>) |
| portrait_generator_v001 | [config](configs/disco/disco-diffusion_portrait_generator_v001.py) | [model](https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u-cvt-rgb_portrait-v001-f4a3f3bc.pth) |
| pixelartdiffusion_expanded | Coming soon! | |
| pixel_art_diffusion_hard_256 | Coming soon! | |
| pixel_art_diffusion_soft_256 | Coming soon! | |
| pixelartdiffusion4k | Coming soon! | |
| watercolordiffusion_2 | Coming soon! | |
| watercolordiffusion | Coming soon! | |
| PulpSciFiDiffusion | Coming soon! | |

## To-do List

32 changes: 31 additions & 1 deletion configs/disco_diffusion/metafile.yml
@@ -6,4 +6,34 @@ Collections:
  Paper:
  - https://github.com/alembics/disco-diffusion
  README: configs/disco_diffusion/README.md
Models: []
Models:
- Config: configs/disco/disco-diffusion_adm-u-finetuned_imagenet-512x512.py
  In Collection: Disco Diffusion
  Metadata:
    Training Data: Others
  Name: disco-diffusion_adm-u-finetuned_imagenet-512x512
  Results:
  - Dataset: Others
    Metrics: {}
    Task: Text2Image, Image2Image
  Weights: https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u_finetuned_imagenet-512x512-ab471d70.pth
- Config: configs/disco/disco-diffusion_adm-u-finetuned_imagenet-256x256.py
  In Collection: Disco Diffusion
  Metadata:
    Training Data: Others
  Name: disco-diffusion_adm-u-finetuned_imagenet-256x256
  Results:
  - Dataset: Others
    Metrics: {}
    Task: Text2Image, Image2Image
  Weights: <>
- Config: configs/disco/disco-diffusion_portrait_generator_v001.py
  In Collection: Disco Diffusion
  Metadata:
    Training Data: Others
  Name: disco-diffusion_portrait_generator_v001
  Results:
  - Dataset: Others
    Metrics: {}
    Task: Text2Image, Image2Image
  Weights: https://download.openmmlab.com/mmediting/synthesizers/disco/adm-u-cvt-rgb_portrait-v001-f4a3f3bc.pth
11 changes: 11 additions & 0 deletions demo/README.md
@@ -28,6 +28,8 @@ Table of contents:

&#8195;     [2.2.8. Video Super-Resolution example](#228-video-super-resolution)

&#8195;     [2.2.9. Text-to-Image example](#229-text-to-image)

[3. Other demos](#3-other-demos)

## 1. Download sample images or videos
@@ -173,6 +175,15 @@ python mmediting_inference_demo.py \
--result-out-dir ../resources/output/video_restoration/demo_video_restoration_edvr_res.mp4
```

#### 2.2.9 Text-to-Image

```shell
python mmediting_inference_demo.py \
--model-name disco \
--text 0=["clouds surround the mountains and Chinese palaces,sunshine,lake,overlook,overlook,unreal engine,light effect,Dream,Greg Rutkowski,James Gurney,artstation"] \
--result-out-dir ../resources/output/text2image/demo_text2image_disco_res.png
```
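The `--text 0=[...]` form above is parsed (via mmengine's `DictAction`) into a dict mapping an index to a list of prompt strings, i.e. `{0: ["clouds surround the mountains...", ...]}`. In Disco Diffusion notebooks, each string in such a list is one prompt, optionally carrying a `:<weight>` suffix. As a rough, self-contained sketch of that convention (`split_prompt` is a hypothetical helper, not part of mmediting, and whether mmediting honours the weight suffix is an assumption):

```python
def split_prompt(prompt, default_weight=1.0):
    """Split a Disco-style prompt into (text, weight).

    A trailing ':<number>' is treated as the weight; anything else
    falls back to `default_weight`.
    """
    text, sep, tail = prompt.rpartition(':')
    if sep:
        try:
            return text, float(tail)
        except ValueError:
            pass  # the colon was part of the prompt text, not a weight
    return prompt, default_weight


print(split_prompt('a misty mountain palace:5'))  # ('a misty mountain palace', 5.0)
print(split_prompt('artstation'))                 # ('artstation', 1.0)
```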

#### 2.2.10 3D-aware Generation (EG3D)

```shell
5 changes: 5 additions & 0 deletions demo/mmediting_inference_demo.py
@@ -23,6 +23,11 @@ def parse_args():
        type=str,
        default=None,
        help='path to input mask file for inpainting models')
    parser.add_argument(
        '--text',
        nargs='+',
        action=DictAction,
        help='text input for text2image models')
    parser.add_argument(
        '--result-out-dir',
        type=str,
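`DictAction` in the hunk above comes from mmengine; it collects repeated `KEY=VALUE` tokens into one dict. A toy stand-in using only the standard library, sketched under the assumption that values are Python literals like the `0=["..."]` example in the demo README:

```python
import argparse
import ast


class SimpleDictAction(argparse.Action):
    """Toy stand-in for mmengine's DictAction: collects KEY=VALUE
    pairs into a single dict, parsing each value as a Python literal."""

    def __call__(self, parser, namespace, values, option_string=None):
        result = getattr(namespace, self.dest) or {}
        for pair in values:
            key, _, raw = pair.partition('=')
            try:
                value = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                value = raw  # fall back to the raw string
            # store numeric keys as ints, mirroring `--text 0=[...]`
            result[int(key) if key.isdigit() else key] = value
        setattr(namespace, self.dest, result)


parser = argparse.ArgumentParser()
parser.add_argument('--text', nargs='+', action=SimpleDictAction)
args = parser.parse_args(['--text', '0=["a castle", "artstation"]'])
print(args.text)  # {0: ['a castle', 'artstation']}
```

The real `DictAction` handles more syntax (nested options, tuples); this only illustrates the dict-building idea.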
82 changes: 70 additions & 12 deletions demo/mmediting_inference_tutorial.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions mmedit/apis/inferencers/base_mmedit_inferencer.py
@@ -86,6 +86,16 @@ def _init_extra_parameters(self, extra_parameters: Dict) -> None:
            if key in extra_parameters.keys():
                self.extra_parameters[key] = extra_parameters[key]

    def _update_extra_parameters(self, **kwargs) -> None:
        """Update ``extra_parameters`` at run time."""
        if 'extra_parameters' in kwargs:
            input_extra_parameters = kwargs['extra_parameters']
            if input_extra_parameters is not None:
                for key in self.extra_parameters.keys():
                    if key in input_extra_parameters.keys():
                        self.extra_parameters[key] = \
                            input_extra_parameters[key]

    def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
        """Dispatch kwargs to preprocess(), forward(), visualize() and
        postprocess() according to the actual demands."""
@@ -112,6 +122,8 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
        Returns:
            Union[Dict, List[Dict]]: Results of inference pipeline.
        """
        self._update_extra_parameters(**kwargs)

        params = self._dispatch_kwargs(**kwargs)
        preprocess_kwargs = self.base_params[0].copy()
        preprocess_kwargs.update(params[0])
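The `_update_extra_parameters` hook added above overwrites only keys that already exist in `extra_parameters`, silently ignoring unknown ones. That merge rule, isolated as a standalone sketch:

```python
def update_known_keys(defaults, overrides):
    """Overwrite only keys already present in `defaults`.

    Keys in `overrides` that `defaults` does not know about are
    ignored, mirroring the `_update_extra_parameters` hook above.
    """
    if not overrides:
        return defaults
    for key in defaults:
        if key in overrides:
            defaults[key] = overrides[key]
    return defaults


params = {'num_inference_steps': 1000, 'seed': 2022}
update_known_keys(params, {'num_inference_steps': 2, 'bogus': 'ignored'})
print(params)  # {'num_inference_steps': 2, 'seed': 2022}
```

Silently dropping unknown keys keeps the call site simple, at the cost of hiding typos in parameter names.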
4 changes: 4 additions & 0 deletions mmedit/apis/inferencers/mmedit_inferencer.py
@@ -10,6 +10,7 @@
from .inpainting_inferencer import InpaintingInferencer
from .matting_inferencer import MattingInferencer
from .restoration_inferencer import RestorationInferencer
from .text2image_inferencer import Text2ImageInferencer
from .translation_inferencer import TranslationInferencer
from .unconditional_inferencer import UnconditionalInferencer
from .video_interpolation_inferencer import VideoInterpolationInferencer
@@ -59,6 +60,9 @@ def __init__(self,
                config, ckpt, device, extra_parameters, seed=seed)
        elif self.task in ['video_interpolation', 'Video Interpolation']:
            self.inferencer = VideoInterpolationInferencer(
                config, ckpt, device, extra_parameters)
        elif self.task in ['text2image', 'Text2Image']:
            self.inferencer = Text2ImageInferencer(
                config, ckpt, device, extra_parameters, seed=seed)
        elif self.task in ['3D_aware_generation', '3D-aware Generation']:
            self.inferencer = EG3DInferencer(
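The elif chain above maps each task alias to an inferencer class. The same dispatch can be written table-driven, which scales better as tasks are added; the classes below are placeholders standing in for mmediting's real inferencers, so this is a sketch of the pattern rather than the project's actual code:

```python
# Placeholder inferencers standing in for the real classes.
class VideoInterpolationInferencer: ...
class Text2ImageInferencer: ...
class EG3DInferencer: ...


# Each task is reachable under two aliases, as in the elif chain.
TASK_TO_INFERENCER = {
    'video_interpolation': VideoInterpolationInferencer,
    'Video Interpolation': VideoInterpolationInferencer,
    'text2image': Text2ImageInferencer,
    'Text2Image': Text2ImageInferencer,
    '3D_aware_generation': EG3DInferencer,
    '3D-aware Generation': EG3DInferencer,
}


def pick_inferencer(task):
    """Return the inferencer class for `task`, or raise on unknown tasks."""
    try:
        return TASK_TO_INFERENCER[task]
    except KeyError:
        raise ValueError(f'Unknown task: {task!r}') from None


print(pick_inferencer('Text2Image').__name__)  # Text2ImageInferencer
```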
83 changes: 83 additions & 0 deletions mmedit/apis/inferencers/text2image_inferencer.py
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List

import numpy as np
from mmengine import mkdir_or_exist
from torchvision.utils import save_image

from .base_mmedit_inferencer import BaseMMEditInferencer, InputsType, PredType


class Text2ImageInferencer(BaseMMEditInferencer):
    """Inferencer that predicts with text2image models."""

    func_kwargs = dict(
        preprocess=['text'],
        forward=[],
        visualize=['result_out_dir'],
        postprocess=[])

    extra_parameters = dict(
        scheduler_kwargs=None,
        height=None,
        width=None,
        init_image=None,
        batch_size=1,
        num_inference_steps=1000,
        skip_steps=0,
        show_progress=False,
        text_prompts=[],
        image_prompts=[],
        eta=0.8,
        clip_guidance_scale=5000,
        init_scale=1000,
        tv_scale=0.,
        sat_scale=0.,
        range_scale=150,
        cut_overview=[12] * 400 + [4] * 600,
        cut_innercut=[4] * 400 + [12] * 600,
        cut_ic_pow=[1] * 1000,
        cut_icgray_p=[0.2] * 400 + [0] * 600,
        cutn_batches=4,
        seed=2022)

    def preprocess(self, text: InputsType) -> Dict:
        """Process the inputs into a model-feedable format.

        Args:
            text (InputsType): Text input for the text-to-image model.

        Returns:
            result (Dict): Results of preprocess.
        """
        result = self.extra_parameters
        result['text_prompts'] = text

        return result

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model."""
        image = self.model.infer(**inputs)['samples']

        return image

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = None) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            preds (List[Union[str, np.ndarray]]): Forward results
                by the inferencer.
            result_out_dir (str): Output path of the image.
                Defaults to None.

        Returns:
            List[np.ndarray]: Result of visualize.
        """
        if result_out_dir:
            mkdir_or_exist(os.path.dirname(result_out_dir))
            save_image(preds, result_out_dir, normalize=True)

        return preds
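Several of the defaults above are per-step schedules: `cut_overview=[12] * 400 + [4] * 600` is a 1000-entry list indexed by timestep, so early steps take 12 whole-image cuts and later steps 4, while `cut_innercut` does the reverse for zoomed-in cuts. A short sketch of how such a schedule can be consumed; the scaled-index lookup is an assumption based on Disco Diffusion notebooks, not mmediting's exact code:

```python
num_steps = 1000
cut_overview = [12] * 400 + [4] * 600   # whole-image cuts per step
cut_innercut = [4] * 400 + [12] * 600   # zoomed-in cuts per step


def cuts_at(step, schedule, num_steps=num_steps):
    """Look up a schedule entry, scaling `step` into the schedule's range."""
    index = min(int(step * len(schedule) / num_steps), len(schedule) - 1)
    return schedule[index]


print(cuts_at(0, cut_overview), cuts_at(999, cut_overview))    # 12 4
print(cuts_at(0, cut_innercut), cuts_at(999, cut_innercut))    # 4 12
```

Trading overview cuts for innercuts as denoising progresses pushes CLIP guidance toward global composition early and local detail late.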
3 changes: 3 additions & 0 deletions mmedit/edit.py
100755 → 100644
@@ -65,6 +65,9 @@ class MMEdit:
        # video_restoration models
        'edvr',

        # text2image models
        'disco_diffusion',

        # 3D-aware generation
        'eg3d',
    ]
33 changes: 33 additions & 0 deletions tests/test_apis/test_inferencers/test_text2image_inferencers.py
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import pytest
import torch

from mmedit.apis.inferencers.text2image_inferencer import Text2ImageInferencer
from mmedit.utils import register_all_modules

register_all_modules()


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_text2image_inferencer():
    cfg = osp.join(
        osp.dirname(__file__), '..', '..', '..', 'configs', 'disco_diffusion',
        'disco-diffusion_adm-u-finetuned_imagenet-512x512.py')
    text = {0: ['sad']}
    result_out_dir = osp.join(
        osp.dirname(__file__), '..', '..', 'data', 'disco_result.png')

    inferencer_instance = \
        Text2ImageInferencer(
            cfg, None, extra_parameters={'num_inference_steps': 2})
    inferencer_instance(text=text)
    inference_result = inferencer_instance(
        text=text, result_out_dir=result_out_dir)
    result_img = inference_result[1]
    assert result_img[0].cpu().numpy().shape == (3, 512, 512)


if __name__ == '__main__':
    test_text2image_inferencer()