From 91493d5b4f610cc80ca4f2d215a049d857ff7f23 Mon Sep 17 00:00:00 2001 From: Xuye Qin Date: Fri, 11 Oct 2024 12:38:21 +0800 Subject: [PATCH] REF: refactor controlnet for image model (#2346) --- examples/StableDiffusionControlNet.ipynb | 8 +- xinference/model/image/core.py | 11 +- xinference/model/image/sdapi.py | 37 ++- .../model/image/stable_diffusion/core.py | 210 +++++++++++++----- 4 files changed, 192 insertions(+), 74 deletions(-) diff --git a/examples/StableDiffusionControlNet.ipynb b/examples/StableDiffusionControlNet.ipynb index 7c9842709c..14c46c3dec 100644 --- a/examples/StableDiffusionControlNet.ipynb +++ b/examples/StableDiffusionControlNet.ipynb @@ -91,7 +91,7 @@ "from diffusers.utils import load_image\n", "\n", "mlsd = MLSDdetector.from_pretrained(\"lllyasviel/ControlNet\")\n", - "image_path = os.path.expanduser(\"~/draft.png\")\n", + "image_path = os.path.expanduser(\"draft.png\")\n", "image = load_image(image_path)\n", "image = mlsd(image)\n", "image" @@ -181,7 +181,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -195,9 +195,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.9" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py index 0db0833340..7b015fd9c4 100644 --- a/xinference/model/image/core.py +++ b/xinference/model/image/core.py @@ -210,18 +210,19 @@ def create_image_model_instance( for name in controlnet: for cn_model_spec in model_spec.controlnet: if cn_model_spec.model_name == name: - if not model_path: - model_path = cache(cn_model_spec) - controlnet_model_paths.append(model_path) + controlnet_model_path = cache(cn_model_spec) + controlnet_model_paths.append(controlnet_model_path) break else: raise ValueError( f"controlnet `{name}` is not supported for model `{model_name}`." ) if len(controlnet_model_paths) == 1: - kwargs["controlnet"] = controlnet_model_paths[0] + kwargs["controlnet"] = (controlnet[0], controlnet_model_paths[0]) else: - kwargs["controlnet"] = controlnet_model_paths + kwargs["controlnet"] = [ + (n, path) for n, path in zip(controlnet, controlnet_model_paths) + ] if not model_path: model_path = cache(model_spec) if peft_model_config is not None: diff --git a/xinference/model/image/sdapi.py b/xinference/model/image/sdapi.py index 6ef21d48ab..f1a0a73e5f 100644 --- a/xinference/model/image/sdapi.py +++ b/xinference/model/image/sdapi.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
 import base64
 import io
 import warnings
 
-from PIL import Image
+from PIL import Image, ImageOps
 
 
 class SDAPIToDiffusersConverter:
@@ -30,7 +31,7 @@ class SDAPIToDiffusersConverter:
     txt2img_arg_mapping = {
         "steps": "num_inference_steps",
         "cfg_scale": "guidance_scale",
-        # "denoising_strength": "strength",
+        "denoising_strength": "strength",
     }
     img2img_identical_args = {
         "prompt",
@@ -42,9 +43,11 @@ class SDAPIToDiffusersConverter:
     }
     img2img_arg_mapping = {
         "init_images": "image",
+        "mask": "mask_image",
         "steps": "num_inference_steps",
         "cfg_scale": "guidance_scale",
         "denoising_strength": "strength",
+        "inpaint_full_res_padding": "padding_mask_crop",
     }
 
     @staticmethod
@@ -121,12 +124,38 @@ def _decode_b64_img(img_str: str) -> Image:
 
     def img2img(self, **kwargs):
         init_images = kwargs.pop("init_images", [])
-        kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
+        kwargs["init_images"] = init_images = [
+            self._decode_b64_img(i) for i in init_images
+        ]
+        if len(init_images) == 1:
+            kwargs["init_images"] = init_images[0]
+        mask_image = kwargs.pop("mask", None)
+        if mask_image:
+            mask_image = self._decode_b64_img(mask_image)
+            if kwargs.pop("inpainting_mask_invert", None):
+                mask_image = ImageOps.invert(mask_image)
+            kwargs["mask"] = mask_image
+
+        # process inpaint_full_res and inpaint_full_res_padding
+        if kwargs.pop("inpaint_full_res", None):
+            kwargs["inpaint_full_res_padding"] = kwargs.pop(
+                "inpaint_full_res_padding", 0
+            )
+        else:
+            # inpaint_full_res_padding is turned into `padding_mask_crop`
+            # in diffusers; if padding_mask_crop is passed, diffusers will do
+            # inpaint_full_res, so when inpaint_full_res is off we pop this option
+            kwargs.pop("inpaint_full_res_padding", None)
+
         clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
         converted_kwargs = self._check_kwargs("img2img", kwargs)
         if clip_skip:
             converted_kwargs["clip_skip"] = clip_skip
-        result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        if not converted_kwargs.get("mask_image"):
+            result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+        else:
+            result = self.inpainting(response_format="b64_json", **converted_kwargs)  # type: ignore
 
         # convert to SD API result
         return {
diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py
index a1914ac7e2..d44837a981 100644
--- a/xinference/model/image/stable_diffusion/core.py
+++ b/xinference/model/image/stable_diffusion/core.py
@@ -14,7 +14,9 @@
 
 import base64
 import contextlib
+import gc
 import inspect
+import itertools
 import logging
 import os
 import re
@@ -25,7 +27,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -93,16 +95,21 @@ def __init__(
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
-        # when a model has text2image ability,
-        # it will be loaded as AutoPipelineForText2Image
-        # for image2image and inpainting,
-        # we convert to the corresponding model
+        # model info when loading
         self._model = None
-        self._i2i_model = None  # image to image model
-        self._inpainting_model = None  # inpainting model
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
+        # deepcache
+        self._deepcache_helper = None
+        # when a model has text2image ability,
+        # it will be loaded as AutoPipelineForText2Image;
+        # for image2image and inpainting,
+        # we convert to the corresponding model
+        self._torch_dtype = None
+        self._ability_to_models: Dict[Tuple[str, Any], Any] = {}
+        self._controlnet_models: Dict[str, Any] = {}
+        # info
         self._model_spec = model_spec
         self._abilities = model_spec.model_ability or []  # type: ignore
         self._kwargs = kwargs
@@ -111,6 +118,65 @@ def model_ability(self):
         return self._abilities
 
+    @staticmethod
+    def _get_pipeline_type(ability: str) -> type:
+        if ability == "text2image":
+            from diffusers import AutoPipelineForText2Image as AutoPipelineModel
+        elif ability == "image2image":
+            from diffusers import AutoPipelineForImage2Image as AutoPipelineModel
+        elif ability == "inpainting":
+            from diffusers import AutoPipelineForInpainting as AutoPipelineModel
+        else:
+            raise ValueError(f"Unknown ability: {ability}")
+        return AutoPipelineModel
+
+    def _get_controlnet_model(self, name: str, path: str):
+        from diffusers import ControlNetModel
+
+        try:
+            return self._controlnet_models[name]
+        except KeyError:
+            logger.debug("Loading controlnet %s from %s", name, path)
+            model = ControlNetModel.from_pretrained(path, torch_dtype=self._torch_dtype)
+            self._controlnet_models[name] = model
+            return model
+
+    def _get_model(
+        self,
+        ability: str,
+        controlnet_name: Optional[Union[str, List[str]]] = None,
+        controlnet_path: Optional[Union[str, List[str]]] = None,
+    ):
+        if isinstance(controlnet_name, list):
+            # a list is unhashable, normalize it to a tuple for the cache key
+            controlnet_name = tuple(controlnet_name)
+        try:
+            return self._ability_to_models[ability, controlnet_name]
+        except KeyError:
+            model_type = self._get_pipeline_type(ability)
+
+            assert self._model is not None
+
+            if controlnet_name:
+                assert controlnet_path
+                if isinstance(controlnet_name, (list, tuple)):
+                    controlnet = []
+                    # multiple controlnets
+                    for name, path in itertools.zip_longest(
+                        controlnet_name, controlnet_path
+                    ):
+                        controlnet.append(self._get_controlnet_model(name, path))
+                else:
+                    controlnet = self._get_controlnet_model(
+                        controlnet_name, controlnet_path
+                    )
+                model = model_type.from_pipe(self._model, controlnet=controlnet)
+            else:
+                model = model_type.from_pipe(self._model)
+            self._load_to_device(model)
+
+            self._ability_to_models[ability, controlnet_name] = model
+            return model
+
     def _apply_lora(self):
         if self._lora_model is not None:
             logger.info(
@@ -132,22 +198,24 @@ def load(self):
         else:
             raise ValueError(f"Unknown ability: {self._abilities}")
 
-        controlnet = self._kwargs.get("controlnet")
-        if controlnet is not None:
-            from diffusers import ControlNetModel
-
-            logger.debug("Loading controlnet %s", controlnet)
-            self._kwargs["controlnet"] = ControlNetModel.from_pretrained(controlnet)
-
-        torch_dtype = self._kwargs.get("torch_dtype")
+        self._torch_dtype = torch_dtype = self._kwargs.get("torch_dtype")
         if sys.platform != "darwin" and torch_dtype is None:
             # The following params crash on Mac M2
-            self._kwargs["torch_dtype"] = torch.float16
+            self._torch_dtype = self._kwargs["torch_dtype"] = torch.float16
             self._kwargs["variant"] = "fp16"
             self._kwargs["use_safetensors"] = True
         if isinstance(torch_dtype, str):
             self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
 
+        controlnet = self._kwargs.get("controlnet")
+        if controlnet is not None:
+            if isinstance(controlnet, tuple):
+                self._kwargs["controlnet"] = self._get_controlnet_model(*controlnet)
+            else:
+                self._kwargs["controlnet"] = [
+                    self._get_controlnet_model(*cn) for cn in controlnet
+                ]
+
         quantize_text_encoder = self._kwargs.pop("quantize_text_encoder", None)
         if quantize_text_encoder:
             try:
@@ -193,27 +261,39 @@ def load(self):
             self._model_path,
             **self._kwargs,
         )
-        if self._kwargs.get("deepcache", True):
-            # NOTE: DeepCache should be loaded first before cpu_offloading
+        self._load_to_device(self._model)
+        self._apply_lora()
+
+        if self._kwargs.get("deepcache", False):
             try:
                 from DeepCache import DeepCacheSDHelper
+            except ImportError:
+                error_message = "Failed to import module 'DeepCache' when launching with deepcache=True"
+                installation_guide = [
+                    "Please make sure 'deepcache' is installed. ",
+                    "You can install it by `pip install deepcache`\n",
+                ]
 
-                helper = DeepCacheSDHelper(pipe=self._model)
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+            else:
+                self._deepcache_helper = helper = DeepCacheSDHelper()
                 helper.set_params(cache_interval=3, cache_branch_id=0)
-                helper.enable()
-        except ImportError:
-            logger.debug("deepcache is not installed")
-            pass
 
+    def _load_to_device(self, model):
         if self._kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
-            self._model.enable_model_cpu_offload()
+            model.enable_model_cpu_offload()
+        elif self._kwargs.get("sequential_cpu_offload", False):
+            logger.debug("CPU sequential offloading model")
+            model.enable_sequential_cpu_offload()
         elif not self._kwargs.get("device_map"):
            logger.debug("Loading model to available device")
-            self._model = move_model_to_available_device(self._model)
+            model = move_model_to_available_device(model)
         # Recommended if your computer has < 64 GB of RAM
-        self._model.enable_attention_slicing()
-        self._apply_lora()
+        if self._kwargs.get("attention_slicing", True):
+            model.enable_attention_slicing()
+        if self._kwargs.get("vae_tiling", False):
+            model.enable_vae_tiling()
 
     @staticmethod
     def _get_scheduler(model: Any, sampler_name: str):
@@ -298,27 +378,49 @@ def _reset_when_done(model: Any, sampler_name: str):
         else:
             yield
 
+    @staticmethod
+    @contextlib.contextmanager
+    def _release_after():
+        from ....device_utils import empty_cache
+
+        try:
+            yield
+        finally:
+            gc.collect()
+            empty_cache()
+
+    @contextlib.contextmanager
+    def _wrap_deepcache(self, model: Any):
+        if self._deepcache_helper:
+            self._deepcache_helper.pipe = model
+            self._deepcache_helper.enable()
+        try:
+            yield
+        finally:
+            if self._deepcache_helper:
+                self._deepcache_helper.disable()
+                self._deepcache_helper.pipe = None
+
     def _call_model(
         self,
         response_format: str,
         model=None,
         **kwargs,
     ):
-        import gc
-
-        from ....device_utils import empty_cache
-
         model = model if model is not None else self._model
         is_padded = kwargs.pop("is_padded", None)
         origin_size = kwargs.pop("origin_size", None)
         seed = kwargs.pop("seed", None)
-        if seed is not None:
-            kwargs["generator"] = generator = torch.Generator(device=get_available_device())  # type: ignore
-            if seed != -1:
-                kwargs["generator"] = generator.manual_seed(seed)
+        return_images = kwargs.pop("_return_images", None)
+        if seed is not None and seed != -1:
+            # seed == -1 means random in the SD API, so only seed the generator otherwise
+            generator = torch.Generator(device=get_available_device())  # type: ignore
+            kwargs["generator"] = generator.manual_seed(seed)
         sampler_name = kwargs.pop("sampler_name", None)
         assert callable(model)
-        with self._reset_when_done(model, sampler_name):
+        with self._reset_when_done(
+            model, sampler_name
+        ), self._release_after(), self._wrap_deepcache(model):
             logger.debug("stable diffusion args: %s, model: %s", kwargs, model)
             self._filter_kwargs(model, kwargs)
             images = model(**kwargs).images
@@ -331,9 +433,8 @@ def _call_model(
                 new_images.append(img.crop((0, 0, x, y)))
             images = new_images
 
-        # clean cache
-        gc.collect()
-        empty_cache()
+        if return_images:
+            return images
 
         if response_format == "url":
             os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
@@ -378,8 +479,6 @@ def text_to_image(
         response_format: str = "url",
         **kwargs,
     ):
-        # References:
-        # https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet_sdxl
         width, height = map(int, re.split(r"[^\d]+", size))
         generate_kwargs = self._model_spec.default_generate_config.copy()  # type: ignore
         generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
@@ -409,19 +508,13 @@ def image_to_image(
         response_format: str = "url",
         **kwargs,
     ):
-        if "controlnet" in self._kwargs:
+        if self._kwargs.get("controlnet"):
             model = self._model
         else:
-            if "image2image" not in self._abilities:
+            ability = "image2image"
+            if ability not in self._abilities:
                 raise RuntimeError(f"{self._model_uid} does not support image2image")
-            if self._i2i_model is not None:
-                model = self._i2i_model
-            else:
-                from diffusers import AutoPipelineForImage2Image
-
-                self._i2i_model = model = AutoPipelineForImage2Image.from_pipe(
-                    self._model
-                )
+            model = self._get_model(ability)
 
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Models like SD3 image-to-image require the image's height and width to be multiples of 16
@@ -462,24 +555,23 @@ def inpainting(
         response_format: str = "url",
         **kwargs,
     ):
-        if "inpainting" not in self._abilities:
+        ability = "inpainting"
+        if ability not in self._abilities:
             raise RuntimeError(f"{self._model_uid} does not support inpainting")
 
         if (
             "text2image" in self._abilities or "image2image" in self._abilities
         ) and self._model is not None:
-            from diffusers import AutoPipelineForInpainting
-
-            if self._inpainting_model is not None:
-                model = self._inpainting_model
-            else:
-                model = self._inpainting_model = AutoPipelineForInpainting.from_pipe(
-                    self._model
-                )
+            model = self._get_model(ability)
         else:
             model = self._model
 
-        width, height = map(int, re.split(r"[^\d]+", size))
+        if mask_blur := kwargs.pop("mask_blur", None):
+            logger.debug("Processing mask image with mask_blur: %s", mask_blur)
+            mask_image = model.mask_processor.blur(mask_image, blur_factor=mask_blur)  # type: ignore
+
+        if "width" not in kwargs:
+            kwargs["width"], kwargs["height"] = map(int, re.split(r"[^\d]+", size))
 
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Models like SD3 inpainting require the image's height and width to be multiples of 16
@@ -492,14 +584,12 @@ def inpainting(
             image, mask_image = self.pad_to_multiple(
                 mask_image, multiple=int(padding_image_to_multiple)
             )
             # calculate actual image size after padding
-            width, height = image.size
+            kwargs["width"], kwargs["height"] = image.size
 
         return self._call_model(
             image=image,
             mask_image=mask_image,
             prompt=prompt,
-            height=height,
-            width=width,
             num_images_per_prompt=n,
             response_format=response_format,
             model=model,
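
To summarize the two conversion conventions this patch introduces, a short sketch follows. All concrete names, paths, and payload values in it are illustrative assumptions, not part of the change itself.

# 1) Controlnet kwargs as now built by create_image_model_instance:
#    a single controlnet becomes one (name, path) tuple, several become a
#    list of tuples, so load() / _get_controlnet_model() can cache each by name.
kwargs = {"controlnet": ("canny", "/cache/controlnet-canny")}  # hypothetical path
kwargs = {
    "controlnet": [
        ("canny", "/cache/controlnet-canny"),
        ("mlsd", "/cache/controlnet-mlsd"),
    ]
}

# 2) An SD-WebUI style img2img payload and its diffusers translation,
#    following SDAPIToDiffusersConverter.img2img_arg_mapping above:
sd_payload = {
    "init_images": ["<b64 image>"],  # -> image (a single image is unwrapped)
    "mask": "<b64 mask>",            # -> mask_image (routes to inpainting())
    "steps": 30,                     # -> num_inference_steps
    "cfg_scale": 7.5,                # -> guidance_scale
    "denoising_strength": 0.75,      # -> strength
    "inpaint_full_res": True,        # consumed; keeps the padding option below
    "inpaint_full_res_padding": 32,  # -> padding_mask_crop (popped entirely
                                     #    when inpaint_full_res is off)
}

Because padding_mask_crop alone triggers masked-region-only inpainting in diffusers, the converter forwards inpaint_full_res_padding only when inpaint_full_res is set, rather than passing both flags through.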