From d9d97a45bb185060ee473da28a1205ba96014744 Mon Sep 17 00:00:00 2001
From: Zeyu Chen <93063038+InsaneOnion@users.noreply.github.com>
Date: Wed, 27 Mar 2024 18:03:28 +0800
Subject: [PATCH] Add Checkpoint Merger Pipeline (#485)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

#271

Checkpoint Merger generation results are aligned: CompVis/stable-diffusion-v1-4 + runwayml/stable-diffusion-v1-5

## torch

![CompVis_runwayml_1](https://github.com/PaddlePaddle/PaddleMIX/assets/93063038/1b090fbc-4300-490a-b551-3b629bc17f9a)

The code is as follows:

```python
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
    ["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"],
    interp="sigmoid",
    alpha=0.4,
)
prompt = "An astronaut riding a horse on Mars"
merged_pipe.to("cuda")
image = merged_pipe(prompt, generator=torch.Generator("cuda").manual_seed(102)).images[0]
image.save("CompVis_runwayml.jpg")
```

## paddle

![CompVis_runwayml](https://github.com/PaddlePaddle/PaddleMIX/assets/93063038/d3059125-b7a4-4506-8516-1407c14cde10)

The code is as follows:

```python
from ppdiffusers import DiffusionPipeline
import paddle

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="/home/onion/workspace/code/pp/PaddleMIX/ppdiffusers/examples/community/checkpoint_merger",
)
merged_pipe = pipe.merge(
    ["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"],
    interp="sigmoid",
    alpha=0.4,
)
prompt = "An astronaut riding a horse on Mars"
image = merged_pipe(prompt, generator=paddle.Generator("cuda").manual_seed(102)).images[0]
image.save("CompVis_runwayml.jpg")
```

The results are identical.
---
 ppdiffusers/examples/community/README.md      |  65 ++++
 .../examples/community/checkpoint_merger.py   | 323 ++++++++++++++++++
 2 files changed, 388 insertions(+)
 create mode 100644 ppdiffusers/examples/community/checkpoint_merger.py

diff --git a/ppdiffusers/examples/community/README.md b/ppdiffusers/examples/community/README.md
index 9e7497c5c..a340f3e90 100644
--- a/ppdiffusers/examples/community/README.md
+++ b/ppdiffusers/examples/community/README.md
@@ -18,6 +18,7 @@
 |EDICT Image Editing Pipeline| A Stable Diffusion pipeline for text-guided image editing|[EDICT Image Editing Pipeline](#edict_pipeline)||
 |FABRIC - Stable Diffusion with feedback Pipeline| A feedback pipeline based on liked and disliked images|[FABRIC - Stable Diffusion with feedback Pipeline](#fabric_pipeline)||
 |Stable Diffusion XL Long Weighted Prompt Pipeline| A pipeline that does not limit prompt length|[Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline)||
+|Checkpoint Merger Pipeline| A Diffusion Pipeline that supports merging model checkpoints|[Checkpoint Merger Pipeline](#checkpoint-merger-pipeline)||
 
 ## Example usages
 
@@ -892,3 +893,67 @@ images.save("out.png")
 The generated image is shown below:
+
+### Checkpoint Merger Pipeline
+
+A Diffusion Pipeline that supports merging model checkpoints. Usage is shown below:
+
+``` python
+from ppdiffusers import DiffusionPipeline
+
+# Return a CheckpointMergerPipeline class that allows you to merge checkpoints.
+# The checkpoint passed here is ignored, but for convenience still pass one of
+# the checkpoints you plan to merge.
+pipe = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    custom_pipeline="checkpoint_merger",
+)
+
+# There are multiple possible scenarios;
+# the pipeline with the merged checkpoints is returned in all of them.
+
+# Compatible checkpoints, i.e. matching model_index.json files. Meta attributes
+# (attrs prefixed with "_") in model_index.json are ignored during the comparison.
+merged_pipe = pipe.merge(
+    ["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"],
+    interp="sigmoid",
+    alpha=0.4,
+)
+
+# Checkpoints whose model_index.json files are incompatible but that may still be
+# mergeable. Use force=True to ignore model_index.json compatibility.
+merged_pipe_1 = pipe.merge(
+    ["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"],
+    force=True,
+    interp="sigmoid",
+    alpha=0.4,
+)
+
+# Three-checkpoint merging. Only the "add_diff" method actually uses all three
+# checkpoints; any other option ignores the 3rd checkpoint.
+merged_pipe_2 = pipe.merge(
+    [
+        "CompVis/stable-diffusion-v1-4",
+        "runwayml/stable-diffusion-v1-5",
+        "prompthero/openjourney",
+    ],
+    force=True,
+    interp="add_diff",
+    alpha=0.4,
+)
+
+prompt = "An astronaut riding a horse on Mars"
+
+image = merged_pipe(prompt).images[0]
+image.save("CompVis_runwayml.jpg")
+image = merged_pipe_1(prompt).images[0]
+image.save("CompVis_prompthero.jpg")
+image = merged_pipe_2(prompt).images[0]
+image.save("CompVis_runwayml_prompthero.jpg")
+```
+
+Some example images and merge details are listed below:
+
+1. "CompVis/stable-diffusion-v1-4" + "runwayml/stable-diffusion-v1-5"; Sigmoid interpolation; alpha = 0.4
+2. "CompVis/stable-diffusion-v1-4" + "prompthero/openjourney" ; Sigmoid interpolation; alpha = 0.4 +
+3. "CompVis/stable-diffusion-v1-4" + "runwayml/stable-diffusion-v1-5" + "prompthero/openjourney" ; Add Difference interpolation; alpha = 0.4 +
diff --git a/ppdiffusers/examples/community/checkpoint_merger.py b/ppdiffusers/examples/community/checkpoint_merger.py
new file mode 100644
index 000000000..7a606d4cf
--- /dev/null
+++ b/ppdiffusers/examples/community/checkpoint_merger.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import os
+from typing import Dict, List, Union
+
+import paddle
+import safetensors.paddle
+
+from ppdiffusers import DiffusionPipeline
+from ppdiffusers.utils import (
+    DIFFUSERS_CACHE,
+    FROM_AISTUDIO,
+    FROM_HF_HUB,
+    PPDIFFUSERS_CACHE,
+)
+
+
+class CheckpointMergerPipeline(DiffusionPipeline):
+    """
+    A class that supports merging diffusion models based on the discussion here:
+    https://github.com/huggingface/diffusers/issues/877
+
+    Example usage:
+
+    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")
+
+    merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"], interp="inv_sigmoid", alpha=0.8, force=True)
+
+    merged_pipe.to("cuda")
+
+    prompt = "An astronaut riding a unicycle on Mars"
+
+    results = merged_pipe(prompt)
+
+    For more details, see the docstring of the merge method.
+    """
+
+    def __init__(self):
+        self.register_to_config()
+        super().__init__()
+
+    def _convert_dict(self, ori_dict):
+        del_flag = False
+        for key, value in ori_dict.items():
+            if isinstance(value, list):
+                for item in value:
+                    if item in ["ppdiffusers", "ppdiffusers.transformers"]:
+                        ori_dict[key] = item
+                    elif item == "diffusers" or item == "diffusers_paddle":
+                        ori_dict[key] = "ppdiffusers"
+                    elif item == "transformers" or item == "paddlenlp.transformers":
+                        ori_dict[key] = "ppdiffusers.transformers"
+            if key == "requires_safety_checker":
+                del_flag = True
+        if del_flag:
+            del ori_dict["requires_safety_checker"]
+        return ori_dict
+
+    def _compare_model_configs(self, dict0, dict1):
+        if dict0 == dict1:
+            return True
+        else:
+            config0, meta_keys0 = self._remove_meta_keys(dict0)
+            config1, meta_keys1 = self._remove_meta_keys(dict1)
+            if config0 == config1:
+                print(f"Warning: mismatch in meta keys {meta_keys0} and {meta_keys1}.")
+                return True
+        return False
+
+    def _remove_meta_keys(self, config_dict: Dict):
+        meta_keys = []
+        temp_dict = config_dict.copy()
+        for key in config_dict.keys():
+            if key.startswith("_"):
+                temp_dict.pop(key)
+                meta_keys.append(key)
+        return (temp_dict, meta_keys)
+
+    @paddle.no_grad()
+    def merge(
+        self,
+        pretrained_model_name_or_path_list: List[Union[str, os.PathLike]],
+        **kwargs,
+    ):
+        """
+        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the models passed in the argument 'pretrained_model_name_or_path_list' as a list.
+
+        Parameters:
+        -----------
+        pretrained_model_name_or_path_list : A list of valid pretrained model names on the HuggingFace hub or paths to locally stored models in the HuggingFace format.
+
+        **kwargs:
+            Supports all the default DiffusionPipeline.get_config_dict kwargs, viz.
+
+            cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, paddle_dtype, device_map.
+
+            alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. An alpha
+                of 0.8 means the first model's checkpoints affect the final result far less than they would with an alpha of 0.2.
+
+            interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
+                Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_diff" is supported.
+
+            force - Whether to ignore mismatches in model_index.json for the given models. Defaults to False.
+
+            variant - Which variant of a pretrained model to load, e.g. "fp16". Defaults to None.
+        """
+        # Default kwargs from DiffusionPipeline
+        from_hf_hub = kwargs.pop("from_hf_hub", FROM_HF_HUB)
+        from_aistudio = kwargs.pop("from_aistudio", FROM_AISTUDIO)
+        cache_dir = kwargs.pop("cache_dir", None)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        token = kwargs.pop("token", None)
+        variant = kwargs.pop("variant", None)
+        revision = kwargs.pop("revision", None)
+        paddle_dtype = kwargs.pop("paddle_dtype", None)
+        device_map = kwargs.pop("device_map", None)
+
+        alpha = kwargs.pop("alpha", 0.5)
+        interp = kwargs.pop("interp", None)
+
+        print("Received list", pretrained_model_name_or_path_list)
+        print(f"Combining with alpha={alpha}, interpolation mode={interp}")
+
+        checkpoint_count = len(pretrained_model_name_or_path_list)
+        # Whether to ignore the result of the model_index.json comparison of the checkpoints
+        force = kwargs.pop("force", False)
+
+        # If fewer than 2 checkpoints, there is nothing to merge. If more than 3, it is not supported for now.
+        if checkpoint_count > 3 or checkpoint_count < 2:
+            raise ValueError(
+                "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
+                " passed."
+            )
+
+        print("Received the right number of checkpoints")
+
+        # Validate that the checkpoints can be merged.
+        # Step 1: Load the model configs and compare the checkpoints. Compare the model_index.json files first while ignoring the keys starting with '_'.
+        config_dicts = []
+        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
+            config_dict = DiffusionPipeline.load_config(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+            )
+            config_dict = self._convert_dict(config_dict)
+            config_dicts.append(config_dict)
+
+        comparison_result = True
+        for idx in range(1, len(config_dicts)):
+            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
+        if not force and comparison_result is False:
+            raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
+        print("Compatible model_index.json files found")
+
+        # Step 2: Basic validation has succeeded. Download the models and save them into our local files.
+        cached_folders = []
+        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
+            if os.path.isdir(pretrained_model_name_or_path):
+                cached_folder = pretrained_model_name_or_path
+            else:
+                DiffusionPipeline.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    resume_download=True,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    revision=revision,
+                    safety_checker=None,
+                    use_safetensors=True,
+                )
+                if from_aistudio:
+                    cached_folder = None  # TODO, check aistudio cache
+                elif from_hf_hub:
+                    cached_folder = os.path.join(DIFFUSERS_CACHE, pretrained_model_name_or_path)
+                else:
+                    cached_folder = os.path.join(PPDIFFUSERS_CACHE, pretrained_model_name_or_path)
+
+            print("Cached folder:", cached_folder)
+            cached_folders.append(cached_folder)
+
+        # Step 3: Load the first checkpoint as a diffusion pipeline and modify its modules' state_dicts in place.
+        final_pipe = DiffusionPipeline.from_pretrained(
+            cached_folders[0],
+            paddle_dtype=paddle_dtype,
+            device_map=device_map,
+            variant=variant,
+            safety_checker=None,
+        )
+        final_pipe.to(self.device)
+
+        checkpoint_path_2 = None
+        if len(cached_folders) > 2:
+            checkpoint_path_2 = os.path.join(cached_folders[2])
+
+        if interp == "sigmoid":
+            theta_func = CheckpointMergerPipeline.sigmoid
+        elif interp == "inv_sigmoid":
+            theta_func = CheckpointMergerPipeline.inv_sigmoid
+        elif interp == "add_diff":
+            theta_func = CheckpointMergerPipeline.add_difference
+        else:
+            theta_func = CheckpointMergerPipeline.weighted_sum
+
+        # Find each module's state dict.
+        for attr in final_pipe.config.keys():
+            if not attr.startswith("_"):
+                checkpoint_path_1 = os.path.join(cached_folders[1], attr)
+                if os.path.exists(checkpoint_path_1):
+                    files = [
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+                        *glob.glob(os.path.join(checkpoint_path_1, "*.pdparams")),
+                    ]
+                    checkpoint_path_1 = files[0] if len(files) > 0 else None
+                if len(cached_folders) < 3:
+                    checkpoint_path_2 = None
+                else:
+                    checkpoint_path_2 = os.path.join(cached_folders[2], attr)
+                    if os.path.exists(checkpoint_path_2):
+                        files = [
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+                            *glob.glob(os.path.join(checkpoint_path_2, "*.pdparams")),
+                        ]
+                        checkpoint_path_2 = files[0] if len(files) > 0 else None
+                # For an attr, if both checkpoint_path_1 and checkpoint_path_2 are None, ignore it.
+                # If at least one is present, deal with it according to the interp method, of course only if the state_dict keys match.
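+                # The block below loads each module's tensors from those paths, verifies
+                # the state_dict keys against the first checkpoint, and then overwrites
+                # the first checkpoint's tensors in place with theta_func's output.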
+                if checkpoint_path_1 is None and checkpoint_path_2 is None:
+                    print(f"Skipping {attr}: not present in the 2nd or 3rd model")
+                    continue
+                try:
+                    module = getattr(final_pipe, attr)
+                    if isinstance(module, bool):  # ignore the requires_safety_checker boolean
+                        continue
+                    theta_0 = getattr(module, "state_dict")
+                    theta_0 = theta_0()
+
+                    update_theta_0 = getattr(module, "load_state_dict")
+                    theta_1 = (
+                        safetensors.paddle.load_file(checkpoint_path_1)
+                        if checkpoint_path_1.endswith(".safetensors")
+                        else paddle.load(checkpoint_path_1)
+                    )
+                    theta_2 = None
+                    if checkpoint_path_2:
+                        theta_2 = (
+                            safetensors.paddle.load_file(checkpoint_path_2)
+                            if checkpoint_path_2.endswith(".safetensors")
+                            else paddle.load(checkpoint_path_2)
+                        )
+
+                    if not theta_0.keys() == theta_1.keys():
+                        print(f"Skipping {attr}: key mismatch")
+                        continue
+                    if theta_2 and not theta_1.keys() == theta_2.keys():
+                        print(f"Skipping {attr}: key mismatch")
+                        continue
+                except Exception as e:
+                    print(f"Skipping {attr} due to an unexpected error: {str(e)}")
+                    continue
+                print(f"MERGING {attr}")
+
+                for key in theta_0.keys():
+                    if theta_2:
+                        theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
+                    else:
+                        theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
+
+                del theta_1
+                del theta_2
+                update_theta_0(theta_0)
+
+                del theta_0
+        return final_pipe
+
+    @staticmethod
+    def weighted_sum(theta0, theta1, theta2, alpha):
+        return ((1 - alpha) * theta0) + (alpha * theta1)
+
+    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    @staticmethod
+    def sigmoid(theta0, theta1, theta2, alpha):
+        alpha = alpha * alpha * (3 - (2 * alpha))
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    @staticmethod
+    def inv_sigmoid(theta0, theta1, theta2, alpha):
+        import math
+
+        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    @staticmethod
+    def add_difference(theta0, theta1, theta2, alpha):
+        return theta0 + (theta1 - theta2) * (1.0 - alpha)