[Feature] Support inference with diffusers pipeline, sd_xl first. #2023

Merged 8 commits on Sep 11, 2023
57 changes: 57 additions & 0 deletions configs/diffusers_pipeline/README.md
@@ -0,0 +1,57 @@
# Diffusers Pipeline (2023)

> [Diffusers Pipeline](https://github.com/huggingface/diffusers)

> **Task**: Diffusers Pipeline

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

For the convenience of our community users, this inferencer supports running inference with pipelines from diffusers, so that their results can be compared with those of the algorithms supported in our algorithm library.

## Configs

| Model | Dataset | Download |
| :---------------------------------------: | :-----: | :------: |
| [diffusers pipeline](./sd_xl_pipeline.py) | - | - |

## Quick Start

### sd_xl_pipeline

To run Stable Diffusion XL with the mmagic inference API, use the following code:

```python
from mmagic.apis import MMagicInferencer

# Create an MMagicInferencer instance and infer
editor = MMagicInferencer(model_name='diffusers_pipeline')
text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed'
negative_prompt = 'bad face, bad hands'
result_out_dir = 'sd_xl_japanese.png'
editor.infer(text=text_prompts,
             negative_prompt=negative_prompt,
             result_out_dir=result_out_dir)
```

You will get an image like this:

<div align="center">
<img src="https://user-images.githubusercontent.com/12782558/266557074-53519887-6597-42cf-8a0b-03c2db3f4ab2.png" width="600"/>
</div>
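
Beyond the defaults, `infer` also forwards the sampling controls listed in this inferencer's `preprocess` (`num_inference_steps`, `height`, `width`). A minimal sketch; the prompt, step count, and output path below are illustrative:

```python
from mmagic.apis import MMagicInferencer

editor = MMagicInferencer(model_name='diffusers_pipeline')
# Fewer denoising steps trade quality for speed; height and width
# are forwarded to the underlying diffusers pipeline.
editor.infer(text='a watercolor painting of a lighthouse',
             negative_prompt='blurry, low quality',
             num_inference_steps=30,
             height=1024,
             width=1024,
             result_out_dir='lighthouse.png')
```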

## Citation

```bibtex
@misc{von-platen-etal-2022-diffusers,
  author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
  title = {Diffusers: State-of-the-art diffusion models},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/diffusers}}
}
```
17 changes: 17 additions & 0 deletions configs/diffusers_pipeline/metafile.yml
@@ -0,0 +1,17 @@
Collections:
- Name: Diffusers Pipeline
  Paper:
    Title: Diffusers Pipeline
    URL: https://github.com/huggingface/diffusers
  README: configs/diffusers_pipeline/README.md
  Task:
  - diffusers pipeline
  Year: 2023
Models:
- Config: configs/diffusers_pipeline/sd_xl_pipeline.py
  In Collection: Diffusers Pipeline
  Name: sd_xl_pipeline
  Results:
  - Dataset: '-'
    Metrics: {}
    Task: Diffusers Pipeline
5 changes: 5 additions & 0 deletions configs/diffusers_pipeline/sd_xl_pipeline.py
@@ -0,0 +1,5 @@
# config for model

model = dict(
    type='DiffusionPipeline',
    from_pretrained='stabilityai/stable-diffusion-xl-base-1.0')
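
The `type` and `from_pretrained` fields resolve to a wrapped diffusers pipeline; conceptually, building this config amounts to roughly the following (a sketch of the equivalent diffusers call, not the wrapper's exact code):

```python
from diffusers import DiffusionPipeline

# What the config conceptually resolves to: load the SD-XL base
# weights through diffusers' standard loading entry point.
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0')
```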
6 changes: 5 additions & 1 deletion mmagic/apis/inferencers/__init__.py
@@ -7,6 +7,7 @@
 from .colorization_inferencer import ColorizationInferencer
 from .conditional_inferencer import ConditionalInferencer
 from .controlnet_animation_inferencer import ControlnetAnimationInferencer
+from .diffusers_pipeline_inferencer import DiffusersPipelineInferencer
 from .eg3d_inferencer import EG3DInferencer
 from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
 from .inpainting_inferencer import InpaintingInferencer
@@ -23,7 +24,7 @@
     'ImageSuperResolutionInferencer', 'Text2ImageInferencer',
     'TranslationInferencer', 'UnconditionalInferencer',
     'VideoInterpolationInferencer', 'VideoRestorationInferencer',
-    'ControlnetAnimationInferencer'
+    'ControlnetAnimationInferencer', 'DiffusersPipelineInferencer'
 ]


@@ -91,6 +92,9 @@ def __init__(self,
         ]:
             self.inferencer = ImageSuperResolutionInferencer(
                 config, ckpt, device, extra_parameters, seed=seed)
+        elif self.task in ['Diffusers Pipeline']:
+            self.inferencer = DiffusersPipelineInferencer(
+                config, ckpt, device, extra_parameters, seed=seed)
         else:
             raise ValueError(f'Unknown inferencer task: {self.task}')
8 changes: 4 additions & 4 deletions mmagic/apis/inferencers/base_mmagic_inferencer.py
@@ -130,10 +130,10 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
         Returns:
             Union[Dict, List[Dict]]: Results of inference pipeline.
         """
-        if 'extra_parameters' in kwargs.keys():
-            if 'infer_with_grad' in kwargs['extra_parameters'].keys():
-                if kwargs['extra_parameters']['infer_with_grad']:
-                    results = self.base_call(**kwargs)
+        if ('extra_parameters' in kwargs.keys()
+                and 'infer_with_grad' in kwargs['extra_parameters'].keys()
+                and kwargs['extra_parameters']['infer_with_grad']):
+            results = self.base_call(**kwargs)
         else:
             with torch.no_grad():
                 results = self.base_call(**kwargs)
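
After this refactor the gradient switch reads as a single condition. For reference, enabling gradient-tracking inference would look roughly like the sketch below (the prompt is illustrative, and it assumes `extra_parameters` is forwarded from `infer()` to the `__call__` shown above):

```python
from mmagic.apis import MMagicInferencer

editor = MMagicInferencer(model_name='diffusers_pipeline')
# When extra_parameters['infer_with_grad'] is truthy, base_call runs
# with autograd enabled; otherwise it runs under torch.no_grad().
editor.infer(text='a sketch of a cat',
             extra_parameters=dict(infer_with_grad=True))
```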
81 changes: 81 additions & 0 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List

import numpy as np
from mmengine import mkdir_or_exist
from PIL.Image import Image
from torchvision.utils import save_image

from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType


class DiffusersPipelineInferencer(BaseMMagicInferencer):
    """Inferencer that predicts with text2image models."""

    func_kwargs = dict(
        preprocess=[
            'text', 'negative_prompt', 'num_inference_steps', 'height',
            'width'
        ],
        forward=[],
        visualize=['result_out_dir'],
        postprocess=[])

    def preprocess(self,
                   text: InputsType,
                   negative_prompt: InputsType = None,
                   num_inference_steps: int = 20,
                   height=None,
                   width=None) -> Dict:
        """Process the inputs into a model-feedable format.

        Args:
            text(InputsType): Text input for the text-to-image model.
            negative_prompt(InputsType): Negative prompt.
            num_inference_steps(int): Number of denoising steps.
                Defaults to 20.
            height(int): Height of the output image. Defaults to None.
            width(int): Width of the output image. Defaults to None.

        Returns:
            result(Dict): Results of preprocess.
        """
        result = self.extra_parameters
        result['prompt'] = text
        if negative_prompt:
            result['negative_prompt'] = negative_prompt
        if num_inference_steps:
            result['num_inference_steps'] = num_inference_steps
        if height:
            result['height'] = height
        if width:
            result['width'] = width

        return result

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model."""
        images = self.model(**inputs).images

        return images

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = None) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            preds (List[Union[str, np.ndarray]]): Forward results
                by the inferencer.
            result_out_dir (str): Output path of the image.
                Defaults to None.

        Returns:
            List[np.ndarray]: Result of visualize
        """
        if result_out_dir:
            mkdir_or_exist(os.path.dirname(result_out_dir))
            if isinstance(preds, list):
                preds = preds[0]
            if isinstance(preds, Image):
                preds.save(result_out_dir)
            else:
                save_image(preds, result_out_dir, normalize=True)

        return preds
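
Putting the pieces together, the inferencer can also be driven directly rather than through `MMagicInferencer`; a minimal sketch (checkpoint, prompt, and output path are illustrative):

```python
from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
    DiffusersPipelineInferencer
from mmagic.utils import register_all_modules

register_all_modules()

cfg = dict(
    model=dict(
        type='DiffusionPipeline',
        from_pretrained='stabilityai/stable-diffusion-xl-base-1.0'))
inferencer = DiffusersPipelineInferencer(cfg, None)
# Runs preprocess -> forward -> visualize and saves the image.
inferencer(
    text='an oil painting of a snowy mountain',
    num_inference_steps=25,
    result_out_dir='mountain.png')
```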
7 changes: 5 additions & 2 deletions mmagic/apis/mmagic_inferencer.py
@@ -114,11 +114,14 @@ class MMagicInferencer:
         # 3D-aware generation
         'eg3d',

-        # diffusers inferencer
+        # animation inferencer
         'controlnet_animation',

         # draggan
-        'draggan'
+        'draggan',
+
+        # diffusers pipeline inferencer
+        'diffusers_pipeline',
     ]

     inference_supported_models_cfg = {}
15 changes: 13 additions & 2 deletions mmagic/models/archs/__init__.py
@@ -63,10 +63,21 @@ def gen_wrapped_cls(module, module_name):
         wrapped_module = gen_wrapped_cls(module, module_name)
         MODELS.register_module(name=module_name, module=wrapped_module)
         DIFFUSERS_MODELS.append(module_name)
-    return DIFFUSERS_MODELS

+    DIFFUSERS_PIPELINES = []
+    for pipeline_name in dir(diffusers.pipelines):
+        pipeline = getattr(diffusers.pipelines, pipeline_name)
+        if (inspect.isclass(pipeline)
+                and issubclass(pipeline, diffusers.DiffusionPipeline)):
+            wrapped_pipeline = gen_wrapped_cls(pipeline, pipeline_name)
+            MODELS.register_module(name=pipeline_name, module=wrapped_pipeline)
+            DIFFUSERS_PIPELINES.append(pipeline_name)
+
+    return DIFFUSERS_MODELS, DIFFUSERS_PIPELINES

-REGISTERED_DIFFUSERS_MODELS = register_diffusers_models()
+
+REGISTERED_DIFFUSERS_MODELS, REGISTERED_DIFFUSERS_PIPELINES = \
+    register_diffusers_models()

 __all__ = [
     'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule',
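
Once registered, any diffusers pipeline class that subclasses `DiffusionPipeline` can be built by name through the `MODELS` registry; a hedged sketch (the checkpoint is illustrative):

```python
from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()

# The wrapper class registered above forwards `from_pretrained`
# to the underlying diffusers pipeline's loader.
pipe = MODELS.build(
    dict(
        type='DiffusionPipeline',
        from_pretrained='runwayml/stable-diffusion-v1-5'))
```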
23 changes: 23 additions & 0 deletions mmagic/models/archs/wrapper.py
@@ -177,3 +177,26 @@ def forward(self, *args, **kwargs) -> Any:
             Any: The output of wrapped module's forward function.
         """
         return self.model(*args, **kwargs)
+
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        """Put the wrapped module on a device or convert it to torch_dtype.
+
+        There are two to() functions: one is nn.Module.to() and the other
+        is diffusers.pipeline.to(); if both args are passed,
+        diffusers.pipeline.to() is called.
+
+        Args:
+            torch_device: The device to put to.
+            torch_dtype: The type to convert to.
+
+        Returns:
+            self: The wrapped module itself.
+        """
+        if torch_dtype is None:
+            self.model.to(torch_device)
+        else:
+            self.model.to(torch_device, torch_dtype)
+        return self
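
For example, moving a wrapped pipeline to the GPU in half precision would look roughly like the sketch below (assumes a CUDA device; the checkpoint is illustrative):

```python
import torch

from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()

pipe = MODELS.build(
    dict(
        type='DiffusionPipeline',
        from_pretrained='stabilityai/stable-diffusion-xl-base-1.0'))
# torch_dtype is given, so the diffusers pipeline's to() is used:
# move to the GPU and convert the weights to fp16.
pipe.to('cuda', torch.float16)
```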
1 change: 1 addition & 0 deletions model-index.yml
@@ -12,6 +12,7 @@ Import:
 - configs/deepfillv1/metafile.yml
 - configs/deepfillv2/metafile.yml
 - configs/dic/metafile.yml
+- configs/diffusers_pipeline/metafile.yml
 - configs/dim/metafile.yml
 - configs/disco_diffusion/metafile.yml
 - configs/draggan/metafile.yml
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
    DiffusersPipelineInferencer
from mmagic.utils import register_all_modules

register_all_modules()


@pytest.mark.skipif(
    'win' in platform.system().lower()
    or digit_version(TORCH_VERSION) <= digit_version('1.8.1'),
    reason='skip on windows due to limited RAM '
    'and get_submodule requires torch >= 1.9.0')
def test_diffusers_pipeline_inferencer():
    cfg = dict(
        model=dict(
            type='DiffusionPipeline',
            from_pretrained='runwayml/stable-diffusion-v1-5'))

    inferencer_instance = DiffusersPipelineInferencer(cfg, None)

    def mock_encode_prompt(prompt, do_classifier_free_guidance,
                           num_images_per_prompt, *args, **kwargs):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        batch_size *= num_images_per_prompt
        if do_classifier_free_guidance:
            batch_size *= 2  # classifier-free guidance doubles the batch
        return torch.randn(batch_size, 5, 16)

    inferencer_instance.model._encode_prompt = mock_encode_prompt

    text_prompts = 'Japanese anime style, girl'
    negative_prompt = 'bad face, bad hands'
    result = inferencer_instance(
        text=text_prompts,
        negative_prompt=negative_prompt,
        height=64,
        width=64)
    assert result[1][0].size == (64, 64)


def teardown_module():
    import gc
    gc.collect()
    globals().clear()
    locals().clear()