diff --git a/configs/diffusers_pipeline/README.md b/configs/diffusers_pipeline/README.md
new file mode 100644
index 0000000000..5e2a8ecb0e
--- /dev/null
+++ b/configs/diffusers_pipeline/README.md
@@ -0,0 +1,76 @@
+# Diffusers Pipeline (2023)
+
+> [Diffusers Pipeline](https://github.com/huggingface/diffusers)
+
+> **Task**: Diffusers Pipeline
+
+
+
+## Abstract
+
+
+
+For the convenience of our community users, this inferencer supports running the pipelines from [diffusers](https://github.com/huggingface/diffusers), so that their results can be compared with those of the algorithms supported within our algorithm library.
+
+## Configs
+
+| Model | Dataset | Download |
+| :---------------------------------------: | :-----: | :------: |
+| [diffusers pipeline](./sd_xl_pipeline.py) | - | - |
+
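+The config simply wraps a diffusers `DiffusionPipeline` and points `from_pretrained` at a Hugging Face Hub model id. As a minimal sketch, a config for the Stable Diffusion v1.5 checkpoint (an illustrative alternative to SD XL) would look like:
+
+```python
+# A sketch: any id accepted by DiffusionPipeline.from_pretrained should work.
+model = dict(
+    type='DiffusionPipeline',
+    from_pretrained='runwayml/stable-diffusion-v1-5')
+```
+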
+## Quick Start
+
+### sd_xl_pipeline
+
+To run Stable Diffusion XL with the MMagic inference API, use the following code:
+
+```python
+from mmagic.apis import MMagicInferencer
+
+# Create an MMagicInferencer instance and infer
+editor = MMagicInferencer(model_name='diffusers_pipeline')
+text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed'
+negative_prompt = 'bad face, bad hands'
+result_out_dir = 'sd_xl_japanese.png'
+editor.infer(text=text_prompts,
+             negative_prompt=negative_prompt,
+             result_out_dir=result_out_dir)
+```
+
+You will get this picture:
+
+
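+The inferencer also accepts `num_inference_steps` (20 by default), `height` and `width`, which are forwarded to the underlying pipeline. A minimal sketch of a customized call, reusing the editor above:
+
+```python
+# More denoising steps trade speed for quality; height and width set the
+# output resolution (SD XL's native resolution is 1024x1024).
+editor.infer(text=text_prompts,
+             negative_prompt=negative_prompt,
+             num_inference_steps=50,
+             height=1024,
+             width=1024,
+             result_out_dir=result_out_dir)
+```
+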
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+  author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
+  title = {Diffusers: State-of-the-art diffusion models},
+  year = {2022},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/configs/diffusers_pipeline/metafile.yml b/configs/diffusers_pipeline/metafile.yml
new file mode 100644
index 0000000000..c61b740cef
--- /dev/null
+++ b/configs/diffusers_pipeline/metafile.yml
@@ -0,0 +1,17 @@
+Collections:
+- Name: Diffusers Pipeline
+  Paper:
+    Title: Diffusers Pipeline
+    URL: https://github.com/huggingface/diffusers
+  README: configs/diffusers_pipeline/README.md
+  Task:
+  - diffusers pipeline
+  Year: 2023
+Models:
+- Config: configs/diffusers_pipeline/sd_xl_pipeline.py
+  In Collection: Diffusers Pipeline
+  Name: sd_xl_pipeline
+  Results:
+  - Dataset: '-'
+    Metrics: {}
+    Task: Diffusers Pipeline
diff --git a/configs/diffusers_pipeline/sd_xl_pipeline.py b/configs/diffusers_pipeline/sd_xl_pipeline.py
new file mode 100644
index 0000000000..3293e2a970
--- /dev/null
+++ b/configs/diffusers_pipeline/sd_xl_pipeline.py
@@ -0,0 +1,6 @@
+# config for model
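+# Other Hub checkpoints can be used by changing 'from_pretrained' below.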
+
+model = dict(
+    type='DiffusionPipeline',
+    from_pretrained='stabilityai/stable-diffusion-xl-base-1.0')
diff --git a/mmagic/apis/inferencers/__init__.py b/mmagic/apis/inferencers/__init__.py
index b175a0c297..66c5710b57 100644
--- a/mmagic/apis/inferencers/__init__.py
+++ b/mmagic/apis/inferencers/__init__.py
@@ -7,6 +7,7 @@
from .colorization_inferencer import ColorizationInferencer
from .conditional_inferencer import ConditionalInferencer
from .controlnet_animation_inferencer import ControlnetAnimationInferencer
+from .diffusers_pipeline_inferencer import DiffusersPipelineInferencer
from .eg3d_inferencer import EG3DInferencer
from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
from .inpainting_inferencer import InpaintingInferencer
@@ -23,7 +24,7 @@
    'ImageSuperResolutionInferencer', 'Text2ImageInferencer',
    'TranslationInferencer', 'UnconditionalInferencer',
    'VideoInterpolationInferencer', 'VideoRestorationInferencer',
-    'ControlnetAnimationInferencer'
+    'ControlnetAnimationInferencer', 'DiffusersPipelineInferencer'
]
@@ -91,6 +92,9 @@ def __init__(self,
        ]:
            self.inferencer = ImageSuperResolutionInferencer(
                config, ckpt, device, extra_parameters, seed=seed)
+        elif self.task in ['Diffusers Pipeline']:
+            self.inferencer = DiffusersPipelineInferencer(
+                config, ckpt, device, extra_parameters, seed=seed)
        else:
            raise ValueError(f'Unknown inferencer task: {self.task}')
diff --git a/mmagic/apis/inferencers/base_mmagic_inferencer.py b/mmagic/apis/inferencers/base_mmagic_inferencer.py
index 99bfca2680..8fb93401e1 100644
--- a/mmagic/apis/inferencers/base_mmagic_inferencer.py
+++ b/mmagic/apis/inferencers/base_mmagic_inferencer.py
@@ -130,10 +130,10 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
        Returns:
            Union[Dict, List[Dict]]: Results of inference pipeline.
        """
-        if 'extra_parameters' in kwargs.keys():
-            if 'infer_with_grad' in kwargs['extra_parameters'].keys():
-                if kwargs['extra_parameters']['infer_with_grad']:
-                    results = self.base_call(**kwargs)
+        if ('extra_parameters' in kwargs.keys()
+                and 'infer_with_grad' in kwargs['extra_parameters'].keys()
+                and kwargs['extra_parameters']['infer_with_grad']):
+            results = self.base_call(**kwargs)
        else:
            with torch.no_grad():
                results = self.base_call(**kwargs)
diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
new file mode 100644
index 0000000000..2143cd64db
--- /dev/null
+++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Dict, List, Optional
+
+import numpy as np
+from mmengine import mkdir_or_exist
+from PIL.Image import Image
+from torchvision.utils import save_image
+
+from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType
+
+
+class DiffusersPipelineInferencer(BaseMMagicInferencer):
+ """inferencer that predicts with text2image models."""
+
+    func_kwargs = dict(
+        preprocess=[
+            'text', 'negative_prompt', 'num_inference_steps', 'height', 'width'
+        ],
+        forward=[],
+        visualize=['result_out_dir'],
+        postprocess=[])
+
+    def preprocess(self,
+                   text: InputsType,
+                   negative_prompt: InputsType = None,
+                   num_inference_steps: int = 20,
+                   height=None,
+                   width=None) -> Dict:
+        """Process the inputs into a model-feedable format.
+
+        Args:
+            text(InputsType): Text prompt for the text-to-image model.
+            negative_prompt(InputsType): Negative prompt.
+                Defaults to None.
+            num_inference_steps(int): Number of denoising steps.
+                Defaults to 20.
+            height(int): Height of the output image. Defaults to None.
+            width(int): Width of the output image. Defaults to None.
+
+        Returns:
+            result(Dict): Results of preprocess.
+        """
+        # Copy so that repeated calls do not accumulate keys in the
+        # shared extra_parameters dict.
+        result = dict(self.extra_parameters)
+        result['prompt'] = text
+        if negative_prompt:
+            result['negative_prompt'] = negative_prompt
+        if num_inference_steps:
+            result['num_inference_steps'] = num_inference_steps
+        if height:
+            result['height'] = height
+        if width:
+            result['width'] = width
+
+        return result
+
+    def forward(self, inputs: InputsType) -> PredType:
+        """Forward the inputs to the model."""
+        images = self.model(**inputs).images
+
+        return images
+
+    def visualize(self,
+                  preds: PredType,
+                  result_out_dir: Optional[str] = None) -> List[np.ndarray]:
+        """Visualize predictions.
+
+        Args:
+            preds (List[Union[str, np.ndarray]]): Forward results
+                by the inferencer.
+            result_out_dir (str, optional): Output path of the image file.
+                Defaults to None.
+
+        Returns:
+            List[np.ndarray]: Results of visualization.
+        """
+        if result_out_dir:
+            mkdir_or_exist(os.path.dirname(result_out_dir))
+            # Diffusers pipelines usually return a list of PIL images.
+            if isinstance(preds, list):
+                preds = preds[0]
+            if isinstance(preds, Image):
+                preds.save(result_out_dir)
+            else:
+                save_image(preds, result_out_dir, normalize=True)
+
+        return preds
diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py
index 634d5e5718..bad7e82df0 100644
--- a/mmagic/apis/mmagic_inferencer.py
+++ b/mmagic/apis/mmagic_inferencer.py
@@ -114,11 +114,14 @@ class MMagicInferencer:
        # 3D-aware generation
        'eg3d',

-        # diffusers inferencer
+        # animation inferencer
        'controlnet_animation',

        # draggan
-        'draggan'
+        'draggan',
+
+        # diffusers pipeline inferencer
+        'diffusers_pipeline',
    ]

    inference_supported_models_cfg = {}
diff --git a/mmagic/models/archs/__init__.py b/mmagic/models/archs/__init__.py
index f33271b509..a67bb3980b 100644
--- a/mmagic/models/archs/__init__.py
+++ b/mmagic/models/archs/__init__.py
@@ -63,10 +63,21 @@ def gen_wrapped_cls(module, module_name):
            wrapped_module = gen_wrapped_cls(module, module_name)
            MODELS.register_module(name=module_name, module=wrapped_module)
            DIFFUSERS_MODELS.append(module_name)
-    return DIFFUSERS_MODELS

+    DIFFUSERS_PIPELINES = []
+    for pipeline_name in dir(diffusers.pipelines):
+        pipeline = getattr(diffusers.pipelines, pipeline_name)
+        if (inspect.isclass(pipeline)
+                and issubclass(pipeline, diffusers.DiffusionPipeline)):
+            wrapped_pipeline = gen_wrapped_cls(pipeline, pipeline_name)
+            MODELS.register_module(name=pipeline_name, module=wrapped_pipeline)
+            DIFFUSERS_PIPELINES.append(pipeline_name)
+
+    return DIFFUSERS_MODELS, DIFFUSERS_PIPELINES
+

-REGISTERED_DIFFUSERS_MODELS = register_diffusers_models()
+REGISTERED_DIFFUSERS_MODELS, REGISTERED_DIFFUSERS_PIPELINES = \
+    register_diffusers_models()

__all__ = [
    'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule',
diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py
index b4dc39f394..ebf141f7c3 100644
--- a/mmagic/models/archs/wrapper.py
+++ b/mmagic/models/archs/wrapper.py
@@ -177,3 +177,27 @@ def forward(self, *args, **kwargs) -> Any:
            Any: The output of wrapped module's forward function.
        """
        return self.model(*args, **kwargs)
+
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        """Put the wrapped module to device or convert it to torch_dtype.
+
+        There are two `to()` methods: `nn.Module.to()` and
+        `diffusers.DiffusionPipeline.to()`. If both arguments are passed,
+        `diffusers.DiffusionPipeline.to()` is called.
+
+        Args:
+            torch_device: The device to put the module to.
+            torch_dtype: The dtype to convert the module to.
+
+        Returns:
+            self: The wrapped module itself.
+        """
+        if torch_dtype is None:
+            self.model.to(torch_device)
+        else:
+            self.model.to(torch_device, torch_dtype)
+        return self
diff --git a/model-index.yml b/model-index.yml
index a0cd9cd232..66c8c194db 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -12,6 +12,7 @@ Import:
- configs/deepfillv1/metafile.yml
- configs/deepfillv2/metafile.yml
- configs/dic/metafile.yml
+- configs/diffusers_pipeline/metafile.yml
- configs/dim/metafile.yml
- configs/disco_diffusion/metafile.yml
- configs/draggan/metafile.yml
diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
new file mode 100644
index 0000000000..dbfeb5c6e8
--- /dev/null
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
+    DiffusersPipelineInferencer
+from mmagic.utils import register_all_modules
+
+register_all_modules()
+
+
+@pytest.mark.skipif(
+    'win' in platform.system().lower()
+    or digit_version(TORCH_VERSION) <= digit_version('1.8.1'),
+    reason='skip on windows due to limited RAM '
+    'and get_submodule requires torch >= 1.9.0')
+def test_diffusers_pipeline_inferencer():
+    cfg = dict(
+        model=dict(
+            type='DiffusionPipeline',
+            from_pretrained='runwayml/stable-diffusion-v1-5'))
+
+    inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+
+    def mock_encode_prompt(prompt, do_classifier_free_guidance,
+                           num_images_per_prompt, *args, **kwargs):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        batch_size *= num_images_per_prompt
+        if do_classifier_free_guidance:
+            batch_size *= 2  # doubled for classifier-free guidance
+        return torch.randn(batch_size, 5, 16)
+
+    inferencer_instance.model._encode_prompt = mock_encode_prompt
+
+    text_prompts = 'Japanese anime style, girl'
+    negative_prompt = 'bad face, bad hands'
+    result = inferencer_instance(
+        text=text_prompts,
+        negative_prompt=negative_prompt,
+        height=64,
+        width=64)
+    assert result[1][0].size == (64, 64)
+
+
+def teardown_module():
+    import gc
+    gc.collect()
+    globals().clear()
+    locals().clear()