From 06242d4f45b44c3c47bec28e04383b78af2beaa6 Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Thu, 7 Sep 2023 21:26:38 +0800
Subject: [PATCH 1/8] add support for diffusers pipeline

---
 configs/diffusers_pipeline/README.md          | 47 ++++++++++++
 configs/diffusers_pipeline/metafile.yml       | 17 +++++
 configs/diffusers_pipeline/sd_xl_pipeline.py  |  6 ++
 mmagic/apis/inferencers/__init__.py           |  6 +-
 .../diffusers_pipeline_inferencer.py          | 73 +++++++++++++++++++
 mmagic/apis/mmagic_inferencer.py              |  7 +-
 mmagic/models/archs/__init__.py               | 15 +++-
 mmagic/models/archs/wrapper.py                |  8 ++
 model-index.yml                               |  1 +
 9 files changed, 175 insertions(+), 5 deletions(-)
 create mode 100644 configs/diffusers_pipeline/README.md
 create mode 100644 configs/diffusers_pipeline/metafile.yml
 create mode 100644 configs/diffusers_pipeline/sd_xl_pipeline.py
 create mode 100644 mmagic/apis/inferencers/diffusers_pipeline_inferencer.py

diff --git a/configs/diffusers_pipeline/README.md b/configs/diffusers_pipeline/README.md
new file mode 100644
index 0000000000..ab2b246d15
--- /dev/null
+++ b/configs/diffusers_pipeline/README.md
@@ -0,0 +1,47 @@
+# Diffusers Pipeline (2023)
+
+> [Diffusers Pipeline](https://github.com/huggingface/diffusers)
+
+> **Task**: Diffusers Pipeline
+
+
+
+## Abstract
+
+
+
+We support diffusers pipelines for users to conveniently use diffusers to do inference in our repo.
+
+## Configs
+
+| Model                                     | Dataset | Download |
+| :---------------------------------------: | :-----: | :------: |
+| [diffusers pipeline](./sd_xl_pipeline.py) | -       | -        |
+
+## Quick Start
+
+```python
+from mmagic.apis import MMagicInferencer
+
+# Create a MMEdit instance and infer
+editor = MMagicInferencer(model_name='diffusers_pipeline')
+text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed'
+negative_prompt = 'bad face, bad hands'
+result_out_dir = 'resources/output/text2image/sd_xl_japanese.png'
+editor.infer(text=text_prompts,
+             negative_prompt=negative_prompt,
+             result_out_dir=result_out_dir)
+```
+
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+  author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
+  title = {Diffusers: State-of-the-art diffusion models},
+  year = {2022},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/configs/diffusers_pipeline/metafile.yml b/configs/diffusers_pipeline/metafile.yml
new file mode 100644
index 0000000000..c61b740cef
--- /dev/null
+++ b/configs/diffusers_pipeline/metafile.yml
@@ -0,0 +1,17 @@
+Collections:
+- Name: Diffusers Pipeline
+  Paper:
+    Title: Diffusers Pipeline
+    URL: https://github.com/huggingface/diffusers
+  README: configs/diffusers_pipeline/README.md
+  Task:
+  - diffusers pipeline
+  Year: 2023
+Models:
+- Config: configs/diffusers_pipeline/sd_xl_pipeline.py
+  In Collection: Diffusers Pipeline
+  Name: sd_xl_pipeline
+  Results:
+  - Dataset: '-'
+    Metrics: {}
+    Task: Diffusers Pipeline
diff --git a/configs/diffusers_pipeline/sd_xl_pipeline.py b/configs/diffusers_pipeline/sd_xl_pipeline.py
new file mode 100644
index 0000000000..e1c66e47e7
--- /dev/null
+++ b/configs/diffusers_pipeline/sd_xl_pipeline.py
@@ -0,0 +1,6 @@
+# config for model
+
+model = dict(
+    type='DiffusionPipeline',
+    from_pretrained='stabilityai/stable-diffusion-xl-base-1.0'
+)
diff --git a/mmagic/apis/inferencers/__init__.py b/mmagic/apis/inferencers/__init__.py
index b175a0c297..66c5710b57 100644
--- a/mmagic/apis/inferencers/__init__.py
+++ b/mmagic/apis/inferencers/__init__.py
@@ -7,6 +7,7 @@
 from .colorization_inferencer import ColorizationInferencer
 from .conditional_inferencer import ConditionalInferencer
 from .controlnet_animation_inferencer import ControlnetAnimationInferencer
+from .diffusers_pipeline_inferencer import DiffusersPipelineInferencer
 from .eg3d_inferencer import EG3DInferencer
 from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
 from .inpainting_inferencer import InpaintingInferencer
@@ -23,7 +24,7 @@
     'ImageSuperResolutionInferencer', 'Text2ImageInferencer',
     'TranslationInferencer', 'UnconditionalInferencer',
     'VideoInterpolationInferencer', 'VideoRestorationInferencer',
-    'ControlnetAnimationInferencer'
+    'ControlnetAnimationInferencer', 'DiffusersPipelineInferencer'
 ]
@@ -91,6 +92,9 @@ def __init__(self,
         ]:
             self.inferencer = ImageSuperResolutionInferencer(
                 config, ckpt, device, extra_parameters, seed=seed)
+        elif self.task in ['Diffusers Pipeline']:
+            self.inferencer = DiffusersPipelineInferencer(
+                config, ckpt, device, extra_parameters, seed=seed)
         else:
             raise ValueError(f'Unknown inferencer task: {self.task}')
diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
new file mode 100644
index 0000000000..5af47ebf13
--- /dev/null
+++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Dict, List
+
+import numpy as np
+from mmengine import mkdir_or_exist
+from PIL.Image import Image
+from torchvision.utils import save_image
+
+from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType
+
+
+class DiffusersPipelineInferencer(BaseMMagicInferencer):
+    """inferencer that predicts with text2image models."""
+
+    func_kwargs = dict(
+        preprocess=['text', 'negative_prompt'],
+        forward=[],
+        visualize=['result_out_dir'],
+        postprocess=[])
+
+    extra_parameters = dict(height=None, width=None)
+
+    def preprocess(self,
+                   text: InputsType,
+                   negative_prompt: InputsType = None) -> Dict:
+        """Process the inputs into a model-feedable format.
+
+        Args:
+            text(InputsType): text input for text-to-image model.
+            negative_prompt(InputsType): negative prompt.
+
+        Returns:
+            result(Dict): Results of preprocess.
+        """
+        result = self.extra_parameters
+        result['prompt'] = text
+
+        if negative_prompt:
+            result['negative_prompt'] = negative_prompt
+
+        return result
+
+    def forward(self, inputs: InputsType) -> PredType:
+        """Forward the inputs to the model."""
+        images = self.model(**inputs).images
+
+        return images
+
+    def visualize(self,
+                  preds: PredType,
+                  result_out_dir: str = None) -> List[np.ndarray]:
+        """Visualize predictions.
+
+        Args:
+            preds (List[Union[str, np.ndarray]]): Forward results
+                by the inferencer.
+            result_out_dir (str): Output directory of image.
+                Defaults to ''.
+
+        Returns:
+            List[np.ndarray]: Result of visualize
+        """
+        if result_out_dir:
+            mkdir_or_exist(os.path.dirname(result_out_dir))
+            if type(preds) is list:
+                preds = preds[0]
+            if type(preds) is Image:
+                preds.save(result_out_dir)
+            else:
+                save_image(preds, result_out_dir, normalize=True)
+
+        return preds
diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py
index 634d5e5718..bad7e82df0 100644
--- a/mmagic/apis/mmagic_inferencer.py
+++ b/mmagic/apis/mmagic_inferencer.py
@@ -114,11 +114,14 @@ class MMagicInferencer:
         # 3D-aware generation
         'eg3d',
 
-        # diffusers inferencer
+        # animation inferencer
         'controlnet_animation',
 
         # draggan
-        'draggan'
+        'draggan',
+
+        # diffusers pipeline inferencer
+        'diffusers_pipeline',
     ]
 
     inference_supported_models_cfg = {}
diff --git a/mmagic/models/archs/__init__.py b/mmagic/models/archs/__init__.py
index f33271b509..a67bb3980b 100644
--- a/mmagic/models/archs/__init__.py
+++ b/mmagic/models/archs/__init__.py
@@ -63,10 +63,21 @@ def gen_wrapped_cls(module, module_name):
         wrapped_module = gen_wrapped_cls(module, module_name)
         MODELS.register_module(name=module_name, module=wrapped_module)
         DIFFUSERS_MODELS.append(module_name)
-    return DIFFUSERS_MODELS
+    DIFFUSERS_PIPELINES = []
+    for pipeline_name in dir(diffusers.pipelines):
+        pipeline = getattr(diffusers.pipelines, pipeline_name)
+        if (inspect.isclass(pipeline)
+                and issubclass(pipeline, diffusers.DiffusionPipeline)):
+            wrapped_pipeline = gen_wrapped_cls(pipeline, pipeline_name)
+            MODELS.register_module(name=pipeline_name, module=wrapped_pipeline)
+            DIFFUSERS_PIPELINES.append(pipeline_name)
 
 
-REGISTERED_DIFFUSERS_MODELS = register_diffusers_models()
+    return DIFFUSERS_MODELS, DIFFUSERS_PIPELINES
+
+
+REGISTERED_DIFFUSERS_MODELS, REGISTERED_DIFFUSERS_PIPELINES = \
+    register_diffusers_models()
 
 __all__ = [
     'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule',
diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py
index b4dc39f394..5b95a9d649 100644
--- a/mmagic/models/archs/wrapper.py
+++ b/mmagic/models/archs/wrapper.py
@@ -177,3 +177,11 @@ def forward(self, *args, **kwargs) -> Any:
             Any: The output of wrapped module's forward function.
         """
         return self.model(*args, **kwargs)
+
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        self.model.to(torch_device, torch_dtype)
+        return self
diff --git a/model-index.yml b/model-index.yml
index a0cd9cd232..66c8c194db 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -12,6 +12,7 @@ Import:
 - configs/deepfillv1/metafile.yml
 - configs/deepfillv2/metafile.yml
 - configs/dic/metafile.yml
+- configs/diffusers_pipeline/metafile.yml
 - configs/dim/metafile.yml
 - configs/disco_diffusion/metafile.yml
 - configs/draggan/metafile.yml
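The hook added in `mmagic/models/archs/__init__.py` wraps every class in `diffusers.pipelines` that subclasses `diffusers.DiffusionPipeline` and registers it in `MODELS` under its class name, which is what lets the config above say `type='DiffusionPipeline'`. A minimal sketch of what this enables (assuming `diffusers` is installed and that the generated wrapper forwards `from_pretrained` to the pipeline's own loader, as the existing diffusers model wrappers do):

```python
from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()  # runs register_diffusers_models() as a side effect

# Build the wrapped diffusers pipeline straight from the registry.
pipeline = MODELS.build(
    dict(
        type='DiffusionPipeline',
        from_pretrained='stabilityai/stable-diffusion-xl-base-1.0'))

# The wrapper delegates forward() to the wrapped pipeline, so the usual
# diffusers calling convention applies.
image = pipeline(prompt='a photo of an astronaut riding a horse').images[0]
```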
From 850b7e339a27d243fe81854c173bd8b6df32781e Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Thu, 7 Sep 2023 21:41:18 +0800
Subject: [PATCH 2/8] fix lint

---
 configs/diffusers_pipeline/sd_xl_pipeline.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/configs/diffusers_pipeline/sd_xl_pipeline.py b/configs/diffusers_pipeline/sd_xl_pipeline.py
index e1c66e47e7..3293e2a970 100644
--- a/configs/diffusers_pipeline/sd_xl_pipeline.py
+++ b/configs/diffusers_pipeline/sd_xl_pipeline.py
@@ -2,5 +2,4 @@
 
 model = dict(
     type='DiffusionPipeline',
-    from_pretrained='stabilityai/stable-diffusion-xl-base-1.0'
-)
+    from_pretrained='stabilityai/stable-diffusion-xl-base-1.0')
""" return self.model(*args, **kwargs) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + ): + self.model.to(torch_device, torch_dtype) + return self diff --git a/model-index.yml b/model-index.yml index a0cd9cd232..66c8c194db 100644 --- a/model-index.yml +++ b/model-index.yml @@ -12,6 +12,7 @@ Import: - configs/deepfillv1/metafile.yml - configs/deepfillv2/metafile.yml - configs/dic/metafile.yml +- configs/diffusers_pipeline/metafile.yml - configs/dim/metafile.yml - configs/disco_diffusion/metafile.yml - configs/draggan/metafile.yml From 850b7e339a27d243fe81854c173bd8b6df32781e Mon Sep 17 00:00:00 2001 From: liuwenran <448073814@qq.com> Date: Thu, 7 Sep 2023 21:41:18 +0800 Subject: [PATCH 2/8] fix lint --- configs/diffusers_pipeline/sd_xl_pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/configs/diffusers_pipeline/sd_xl_pipeline.py b/configs/diffusers_pipeline/sd_xl_pipeline.py index e1c66e47e7..3293e2a970 100644 --- a/configs/diffusers_pipeline/sd_xl_pipeline.py +++ b/configs/diffusers_pipeline/sd_xl_pipeline.py @@ -2,5 +2,4 @@ model = dict( type='DiffusionPipeline', - from_pretrained='stabilityai/stable-diffusion-xl-base-1.0' -) + from_pretrained='stabilityai/stable-diffusion-xl-base-1.0') From 83a30e8a9d59aafd94faed000c855fb73fd361f1 Mon Sep 17 00:00:00 2001 From: liuwenran <448073814@qq.com> Date: Fri, 8 Sep 2023 13:28:22 +0800 Subject: [PATCH 3/8] fix test --- mmagic/models/archs/wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py index 5b95a9d649..b985ac231d 100644 --- a/mmagic/models/archs/wrapper.py +++ b/mmagic/models/archs/wrapper.py @@ -183,5 +183,8 @@ def to( torch_device: Optional[Union[str, torch.device]] = None, torch_dtype: Optional[torch.dtype] = None, ): - self.model.to(torch_device, torch_dtype) + if torch_dtype is None: + self.model.to(torch_device) + else: + self.model.to(torch_device, torch_dtype) return self From a08fd8c2ddc30c5840dc2724aa657b27693e59ff Mon Sep 17 00:00:00 2001 From: liuwenran <448073814@qq.com> Date: Fri, 8 Sep 2023 16:26:40 +0800 Subject: [PATCH 4/8] add ut and fix base inferencer --- .../inferencers/base_mmagic_inferencer.py | 8 ++-- .../diffusers_pipeline_inferencer.py | 18 ++++++--- mmagic/models/archs/wrapper.py | 12 ++++++ .../test_diffusers_pipeline_inferencer.py | 40 +++++++++++++++++++ 4 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py diff --git a/mmagic/apis/inferencers/base_mmagic_inferencer.py b/mmagic/apis/inferencers/base_mmagic_inferencer.py index 99bfca2680..8fb93401e1 100644 --- a/mmagic/apis/inferencers/base_mmagic_inferencer.py +++ b/mmagic/apis/inferencers/base_mmagic_inferencer.py @@ -130,10 +130,10 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]: Returns: Union[Dict, List[Dict]]: Results of inference pipeline. 
""" - if 'extra_parameters' in kwargs.keys(): - if 'infer_with_grad' in kwargs['extra_parameters'].keys(): - if kwargs['extra_parameters']['infer_with_grad']: - results = self.base_call(**kwargs) + if ('extra_parameters' in kwargs.keys() + and 'infer_with_grad' in kwargs['extra_parameters'].keys() + and kwargs['extra_parameters']['infer_with_grad']): + results = self.base_call(**kwargs) else: with torch.no_grad(): results = self.base_call(**kwargs) diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py index 5af47ebf13..2143cd64db 100644 --- a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py +++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py @@ -14,16 +14,19 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer): """inferencer that predicts with text2image models.""" func_kwargs = dict( - preprocess=['text', 'negative_prompt'], + preprocess=[ + 'text', 'negative_prompt', 'num_inference_steps', 'height', 'width' + ], forward=[], visualize=['result_out_dir'], postprocess=[]) - extra_parameters = dict(height=None, width=None) - def preprocess(self, text: InputsType, - negative_prompt: InputsType = None) -> Dict: + negative_prompt: InputsType = None, + num_inference_steps: int = 20, + height=None, + width=None) -> Dict: """Process the inputs into a model-feedable format. Args: @@ -35,9 +38,14 @@ def preprocess(self, """ result = self.extra_parameters result['prompt'] = text - if negative_prompt: result['negative_prompt'] = negative_prompt + if num_inference_steps: + result['num_inference_steps'] = num_inference_steps + if height: + result['height'] = height + if width: + result['width'] = width return result diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py index b985ac231d..ebf141f7c3 100644 --- a/mmagic/models/archs/wrapper.py +++ b/mmagic/models/archs/wrapper.py @@ -183,6 +183,18 @@ def to( torch_device: Optional[Union[str, torch.device]] = None, torch_dtype: Optional[torch.dtype] = None, ): + """Put wrapped module to device or convert it to torch_dtype. There are + two to() function. One is nn.module.to() and the other is + diffusers.pipeline.to(), if both args are passed, + diffusers.pipeline.to() is called. + + Args: + torch_device: The device to put to. + torch_dtype: The type to convert to. + + Returns: + self: the wrapped module itself. + """ if torch_dtype is None: self.model.to(torch_device) else: diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py new file mode 100644 index 0000000000..a16fcd9c74 --- /dev/null +++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest
+import platform
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
+    DiffusersPipelineInferencer
+from mmagic.utils import register_all_modules
+
+register_all_modules()
+
+
+@pytest.mark.skipif(
+    'win' in platform.system().lower()
+    or digit_version(TORCH_VERSION) <= digit_version('1.8.1'),
+    reason='skip on windows due to limited RAM '
+    'and get_submodule requires torch >= 1.9.0')
+def test_diffusers_pipeline_inferencer():
+    cfg = dict(
+        model=dict(
+            type='DiffusionPipeline',
+            from_pretrained='runwayml/stable-diffusion-v1-5'))
+
+    inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+    text_prompts = 'Japanese anime style, girl'
+    negative_prompt = 'bad face, bad hands'
+    result = inferencer_instance(
+        text=text_prompts,
+        negative_prompt=negative_prompt,
+        height=128,
+        width=128)
+    assert result[1][0].size == (128, 128)
+
+
+def teardown_module():
+    import gc
+    gc.collect()
+    globals().clear()
+    locals().clear()

From ffef00d4cf2ce5a2d130023b6502ec648c92d066 Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Fri, 8 Sep 2023 16:32:20 +0800
Subject: [PATCH 5/8] fix lint

---
 .../test_inferencers/test_diffusers_pipeline_inferencer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index a16fcd9c74..27b2d08eb8 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import pytest
 import platform
+
+import pytest
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 

From 0ea00bef84fcd5ad1fcd721b7223abc6a775c1a3 Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Fri, 8 Sep 2023 17:03:24 +0800
Subject: [PATCH 6/8] add readme

---
 configs/diffusers_pipeline/README.md | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/configs/diffusers_pipeline/README.md b/configs/diffusers_pipeline/README.md
index ab2b246d15..5e2a8ecb0e 100644
--- a/configs/diffusers_pipeline/README.md
+++ b/configs/diffusers_pipeline/README.md
@@ -10,7 +10,7 @@
 
 
 
-We support diffusers pipelines for users to conveniently use diffusers to do inference in our repo.
+For the convenience of our community users, this inferencer supports using pipelines from diffusers for inference, so that their results can be compared with the algorithms supported in our repository.
 
 ## Configs
 
@@ -20,6 +20,10 @@
 
 ## Quick Start
 
+### sd_xl_pipeline
+
+To run stable diffusion XL with the mmagic inference API, use the following code:
+
 ```python
 from mmagic.apis import MMagicInferencer
 
@@ -27,12 +31,18 @@ editor = MMagicInferencer(model_name='diffusers_pipeline')
 text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed'
 negative_prompt = 'bad face, bad hands'
-result_out_dir = 'resources/output/text2image/sd_xl_japanese.png'
+result_out_dir = 'sd_xl_japanese.png'
 editor.infer(text=text_prompts,
              negative_prompt=negative_prompt,
              result_out_dir=result_out_dir)
 ```
 
+You will get this picture:
+
+<div align=center>
+  <img src='...' />
+</div>
+
 ## Citation
 
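The README's Quick Start keeps the pipeline defaults. Combined with the parameters added in PATCH 4/8, a fuller call through the high-level API could look like the sketch below (1024x1024 is simply SD-XL's native training resolution, and keyword forwarding is assumed to follow `func_kwargs` as in the sketch after PATCH 4/8):

```python
from mmagic.apis import MMagicInferencer

editor = MMagicInferencer(model_name='diffusers_pipeline')
editor.infer(
    text='Japanese anime style, girl, beautiful, cute, colorful, '
    'best quality, extremely detailed',
    negative_prompt='bad face, bad hands',
    num_inference_steps=50,  # inferencer default is 20
    height=1024,             # SD-XL base is trained at 1024x1024
    width=1024,
    result_out_dir='sd_xl_japanese.png')
```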
From 4be2dd876dad45161f9f7fe22b140a74bad30d7e Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Fri, 8 Sep 2023 17:21:49 +0800
Subject: [PATCH 7/8] reduce test size

---
 .../test_inferencers/test_diffusers_pipeline_inferencer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index 27b2d08eb8..98b49f6e7c 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -29,9 +29,9 @@ def test_diffusers_pipeline_inferencer():
     result = inferencer_instance(
         text=text_prompts,
         negative_prompt=negative_prompt,
-        height=128,
-        width=128)
-    assert result[1][0].size == (128, 128)
+        height=64,
+        width=64)
+    assert result[1][0].size == (64, 64)
 
 
 def teardown_module():
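The final commit below stubs out the pipeline's `_encode_prompt` so CI never loads or runs the CLIP text encoder. The stub only has to honor the batching contract (the batch doubles under classifier-free guidance) and return a tensor; the `(batch, 5, 16)` shape is an arbitrary stand-in. A hypothetical equivalent using `unittest.mock`, which also restores the real method on exit:

```python
from unittest import mock

import torch

from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
    DiffusersPipelineInferencer
from mmagic.utils import register_all_modules

register_all_modules()
inferencer_instance = DiffusersPipelineInferencer(
    dict(
        model=dict(
            type='DiffusionPipeline',
            from_pretrained='runwayml/stable-diffusion-v1-5')), None)


def fake_encode_prompt(prompt, do_classifier_free_guidance,
                       num_images_per_prompt, *args, **kwargs):
    """Return dummy prompt embeddings instead of running CLIP."""
    batch_size = len(prompt) if isinstance(prompt, list) else 1
    batch_size *= num_images_per_prompt
    if do_classifier_free_guidance:
        batch_size *= 2  # unconditional + conditional halves
    return torch.randn(batch_size, 5, 16)  # arbitrary stand-in shape


# patch.object swaps the method for the duration of the block only.
with mock.patch.object(inferencer_instance.model, '_encode_prompt',
                       new=fake_encode_prompt):
    result = inferencer_instance(text='girl', height=64, width=64)
```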
From 6a5c50316b32f6378d22cf4010ba477d00f79858 Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Fri, 8 Sep 2023 20:01:55 +0800
Subject: [PATCH 8/8] mock text encoder

---
 .../test_diffusers_pipeline_inferencer.py     | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index 98b49f6e7c..dbfeb5c6e8 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -2,6 +2,7 @@
 import platform
 
 import pytest
+import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
@@ -24,6 +25,17 @@ def test_diffusers_pipeline_inferencer():
             from_pretrained='runwayml/stable-diffusion-v1-5'))
 
     inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+
+    def mock_encode_prompt(prompt, do_classifier_free_guidance,
+                           num_images_per_prompt, *args, **kwargs):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        batch_size *= num_images_per_prompt
+        if do_classifier_free_guidance:
+            batch_size *= 2  # 2 for cfg
+        return torch.randn(batch_size, 5, 16)
+
+    inferencer_instance.model._encode_prompt = mock_encode_prompt
+
     text_prompts = 'Japanese anime style, girl'
     negative_prompt = 'bad face, bad hands'
     result = inferencer_instance(