diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index e2d812575..a70b10b70 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -321,6 +321,8 @@ jobs: run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_deep_cache_sd_sdxl_enterprise.py --model /share_nfs/stable-diffusion-xl-base-1.0-int8-deep-cache --model_type sdxl --width 512 --height 512 --saved_image output_enterprise_deepcache_sdxl.png - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro') run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_deep_cache_sdxl.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --width 512 --height 512 --run_multiple_resolutions true --saved_image deepcache_sdxl.png + - if: matrix.test-suite == 'diffusers_examples' + run: docker exec -w /src/onediff ${{ env.CONTAINER_NAME }} python3 tests/test_model_inference.py - if: matrix.test-suite == 'diffusers_examples' run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image.py --model_id=/share_nfs/hf_models/stable-diffusion-v1-5 - if: matrix.test-suite == 'diffusers_examples' diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py index 63bce6bb2..6904d5c41 100644 --- a/onediff_comfy_nodes/__init__.py +++ b/onediff_comfy_nodes/__init__.py @@ -1,4 +1,5 @@ """OneDiff ComfyUI Speedup Module""" +from onediff.utils.import_utils import is_nexfort_available, is_oneflow_available from ._config import is_disable_oneflow_backend from ._nodes import ( ControlnetSpeedup, @@ -8,7 +9,6 @@ OneDiffControlNetLoader, VaeSpeedup, ) -from .utils.import_utils import is_nexfort_available, is_oneflow_available NODE_CLASS_MAPPINGS = { "ModelSpeedup": ModelSpeedup, @@ -22,8 +22,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ModelSpeedup": "Model Speedup", "VaeSpeedup": "VAE Speedup", + "OneDiffModelBooster": "Apply Model Booster - OneDiff", "ControlnetSpeedup": "ControlNet Speedup", - "OneDiffModelBooster": "Apply Model Booster - OneDff", "OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff", } diff --git a/onediff_comfy_nodes/_config.py b/onediff_comfy_nodes/_config.py index 8d32de5d1..0ee882e9c 100644 --- a/onediff_comfy_nodes/_config.py +++ b/onediff_comfy_nodes/_config.py @@ -1,4 +1,5 @@ import os +import sys import folder_paths __all__ = [ @@ -14,6 +15,12 @@ os.environ.get("ONEDIFF_COMFY_NODES_DISABLE_ONEFLOW_BACKEND", "0") == "1" ) +custom_nodes_path = os.path.join(folder_paths.base_path, "custom_nodes") + +# Add paths to sys.path if not already there +if custom_nodes_path not in sys.path: + sys.path.append(custom_nodes_path) + if _default_backend not in ["oneflow", "nexfort"]: raise ValueError(f"Invalid default backend: {_default_backend}") diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index 65662609c..ccf695913 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -1,12 +1,13 @@ +from typing import Optional, Tuple import folder_paths import torch import comfy -from onediff.utils.chache_utils import LRUCache +import uuid from nodes import CheckpointLoaderSimple, ControlNetLoader from ._config import is_disable_oneflow_backend -from .modules import BoosterScheduler, BoosterExecutor -from .utils.import_utils import is_nexfort_available # type: ignore -from .utils.import_utils import is_oneflow_available +from 
.modules import BoosterScheduler, BoosterExecutor, BoosterSettings +from onediff.utils.import_utils import is_nexfort_available # type: ignore +from onediff.utils.import_utils import is_oneflow_available if is_oneflow_available() and not is_disable_oneflow_backend(): from .modules.oneflow import BasicOneFlowBoosterExecutor @@ -31,50 +32,66 @@ ] -class ModelSpeedup: - @classmethod - def INPUT_TYPES(s): - return { - "required": {"model": ("MODEL",), "inplace": ([False, True],),}, - "optional": {"custom_booster": ("CUSTOM_BOOSTER",),}, - } +class SpeedupMixin: + """A mix-in class to provide speedup functionality.""" - RETURN_TYPES = ("MODEL",) FUNCTION = "speedup" CATEGORY = "OneDiff" - @torch.no_grad() - def speedup(self, model, inplace=False, custom_booster: BoosterScheduler = None): + @torch.inference_mode() + def speedup( + self, + model, + inplace: bool = False, + custom_booster: Optional[BoosterScheduler] = None, + *args, + **kwargs + ) -> Tuple: + """ + Speed up the model inference. + + Args: + model: The input model to be sped up. + inplace (bool, optional): Whether to perform the operation inplace. Defaults to False. + custom_booster (BoosterScheduler, optional): Custom booster scheduler to use. Defaults to None. + *args: Additional positional arguments to be passed to the underlying functions. + **kwargs: Additional keyword arguments to be passed to the underlying functions. + + Returns: + Tuple: Tuple containing the optimized model. + """ + if not hasattr(self, "booster_settings"): + self.booster_settings = BoosterSettings(tmp_cache_key=str(uuid.uuid4())) + if custom_booster: booster = custom_booster - booster.inplace = False + booster.inplace = inplace else: booster = BoosterScheduler(BasicBoosterExecutor(), inplace=inplace) + booster.settings = self.booster_settings + return (booster(model, *args, **kwargs),) - return (booster(model),) - -class VaeSpeedup: +class ModelSpeedup(SpeedupMixin): @classmethod def INPUT_TYPES(s): return { - "required": {"vae": ("VAE",),}, + "required": {"model": ("MODEL",), "inplace": ([False, True],),}, "optional": {"custom_booster": ("CUSTOM_BOOSTER",),}, } - RETURN_TYPES = ("VAE",) - FUNCTION = "speedup" - CATEGORY = "OneDiff" + RETURN_TYPES = ("MODEL",) - @torch.no_grad() - def speedup(self, vae, custom_booster=None): - if custom_booster: - booster = custom_booster - else: - booster = BoosterScheduler(BasicBoosterExecutor()) - new_vae = booster(vae) - return (new_vae,) +class VaeSpeedup(SpeedupMixin): + @classmethod + def INPUT_TYPES(s): + return { + "required": {"vae": ("VAE",), "inplace": ([False, True],),}, + "optional": {"custom_booster": ("CUSTOM_BOOSTER",),}, + } + + RETURN_TYPES = ("VAE",) class ControlnetSpeedup: @@ -177,8 +194,6 @@ def onediff_load_controlnet(self, control_net_name, custom_booster=None): class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple): - _cache_map = LRUCache(1) - @classmethod def INPUT_TYPES(s): return { @@ -226,11 +241,6 @@ def _load_checkpoint( def onediff_load_checkpoint( self, ckpt_name, vae_speedup="disable", custom_booster: BoosterScheduler = None, ): - cache_key = (ckpt_name, vae_speedup, custom_booster) - out = self._cache_map.get(cache_key, None) - if out is None: - out = self._load_checkpoint(ckpt_name, vae_speedup, custom_booster) - self._cache_map.put(cache_key, out) - + out = self._load_checkpoint(ckpt_name, vae_speedup, custom_booster) # Return the loaded checkpoint (modelpatcher, clip, vae) return out diff --git a/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py 
b/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py index 02a724684..a265f9d0d 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py @@ -1,18 +1,57 @@ +import collections from ..modules.nexfort.booster_basic import BasicNexFortBoosterExecutor -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} + +# https://github.com/siliconflow/nexfort?tab=readme-ov-file#suggested-combinations-of-compiler-modes +compiler_modes = collections.OrderedDict( + { + "jit:disable-runtime-fusion:low-precision": "This compiles super quickly, but the performance might not be optimized very noticeably.", + "jit:benchmark:low-precision:freezing:cudagraphs": "This compiles the model very quickly, but the performance might be not as good as `TorchInductor` optimized models.", + "max-autotune:low-precision": "This will deliver a good performance and adapt quickly to shape changes.", + "max-autotune:benchmark:low-precision:cudagraphs": "This is the most suggested combination of compiler modes. It will deliver a good balance between performance and compilation time.", + "max-optimize:max-autotune:benchmark:low-precision:freezing:cudagraphs": "This is the most aggressive combination of compiler modes. It will deliver the best performance but might slow down the compilation significantly.", + } +) class OneDiffNexfortBooster: - @classmethod def INPUT_TYPES(s): - return {} - + return { + "required": { + "fullgraph": ([False, True],), + "dynamic": ([None, True, False],), + "mode": ([mode for mode in compiler_modes.keys()],), + "docs_link": ( + "STRING", + { + "multiline": True, + "default": "[Note]: \nInstall-nexfort \nhttps://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort#install-nexfort", + }, + ), + } + } + CATEGORY = "OneDiff/Booster" RETURN_TYPES = ("TorchCompileBooster",) FUNCTION = "apply" - def apply(self, *args, **kwargs): - return (BasicNexFortBoosterExecutor(),) \ No newline at end of file + def apply( + self, + fullgraph=False, + dynamic=None, + mode="max-autotune:cudagraphs", + docs_link=None, + ): + return ( + BasicNexFortBoosterExecutor( + fullgraph=fullgraph, mode=f"{mode}:cache-all", dynamic=dynamic + ), + ) + + +NODE_CLASS_MAPPINGS = { + "OneDiffNexfortBooster": OneDiffNexfortBooster, +} + +NODE_DISPLAY_NAME_MAPPINGS = {"OneDiffNexfortBooster": "Nexfort Booster - OneDiff"} diff --git a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py index 3a082493c..02b371d2a 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py @@ -7,9 +7,9 @@ from comfy import model_management from comfy.cli_args import args -from onediff.infer_compiler.backends.oneflow.utils.version_util import ( - is_community_version, -) +from onediff.utils.import_utils import is_onediff_quant_available +from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version + from ..modules import BoosterScheduler from ..modules.oneflow import ( @@ -28,7 +28,7 @@ from ..modules.oneflow.hijack_utils import comfy_utils_hijack from ..modules.oneflow.utils import OUTPUT_FOLDER, load_graph, save_graph -from ..utils.import_utils import is_onediff_quant_available +from ..modules import BoosterScheduler if is_onediff_quant_available() and not is_community_version(): from ..modules.oneflow.booster_quantization import ( diff --git 
a/onediff_comfy_nodes/modules/__init__.py b/onediff_comfy_nodes/modules/__init__.py index 2a8604f5b..21c7ce109 100644 --- a/onediff_comfy_nodes/modules/__init__.py +++ b/onediff_comfy_nodes/modules/__init__.py @@ -1,2 +1,2 @@ -from .booster_interface import BoosterExecutor -from .booster_scheduler import BoosterScheduler \ No newline at end of file +from .booster_interface import BoosterExecutor, BoosterSettings +from .booster_scheduler import BoosterScheduler diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py new file mode 100644 index 000000000..392dc5f41 --- /dev/null +++ b/onediff_comfy_nodes/modules/booster_cache.py @@ -0,0 +1,49 @@ +import torch +import traceback +from collections import OrderedDict +from comfy.model_patcher import ModelPatcher +from comfy.sd import VAE +from onediff.torch_utils.module_operations import get_sub_module +from onediff.utils.import_utils import is_oneflow_available + +if is_oneflow_available(): + from .oneflow.utils.booster_utils import is_using_oneflow_backend + + +def switch_to_cached_model(new_model: ModelPatcher, cache_model): + assert type(new_model.model) == type(cache_model) + for k, v in new_model.model.state_dict().items(): + cached_v: torch.Tensor = get_sub_module(cache_model, k) + assert v.dtype == cached_v.dtype + cached_v.copy_(v) + new_model.model = cache_model + return new_model + + +class BoosterCacheService: + _cache = OrderedDict() + + def put(self, key, model): + if key is None: + return + # oneflow backends output image error + if is_oneflow_available() and is_using_oneflow_backend(model): + return + self._cache[key] = model.model + + def get(self, key, default=None): + return self._cache.get(key, default) + + def get_cached_model(self, key, model): + cached_model = self.get(key, None) + print(f"Cache lookup: Key='{key}', Cached Model Type='{type(cached_model)}'") + if cached_model is not None: + try: + return switch_to_cached_model(model, cached_model) + except Exception as e: + print("An exception occurred when switching to cached model:") + print(traceback.format_exc()) + del self._cache[key] + torch.cuda.empty_cache() + + return None diff --git a/onediff_comfy_nodes/modules/booster_interface.py b/onediff_comfy_nodes/modules/booster_interface.py index ec2967ccf..2abacfb59 100644 --- a/onediff_comfy_nodes/modules/booster_interface.py +++ b/onediff_comfy_nodes/modules/booster_interface.py @@ -1,5 +1,7 @@ # import os +import uuid from abc import ABC, abstractmethod +import dataclasses # from functools import singledispatchmethod # from typing import Optional @@ -10,6 +12,7 @@ # from comfy.model_patcher import ModelPatcher # from comfy.sd import VAE + class BoosterExecutor(ABC): """Interface for optimization.""" @@ -17,3 +20,14 @@ class BoosterExecutor(ABC): def execute(self, model, ckpt_name=None, **kwargs): """Apply the optimization strategy to the model.""" pass + + +@dataclasses.dataclass +class BoosterSettings: + tmp_cache_key: str = None + + +if __name__ == "__main__": + print(BoosterSettings(str(uuid.uuid4())).tmp_cache_key) + print(BoosterSettings(str(uuid.uuid4())).tmp_cache_key) + print(BoosterSettings(str(uuid.uuid4())).tmp_cache_key) diff --git a/onediff_comfy_nodes/modules/booster_scheduler.py b/onediff_comfy_nodes/modules/booster_scheduler.py index 0b15665a1..6b5c29260 100644 --- a/onediff_comfy_nodes/modules/booster_scheduler.py +++ b/onediff_comfy_nodes/modules/booster_scheduler.py @@ -1,47 +1,76 @@ import copy -from functools import singledispatchmethod +import 
torch.nn as nn +from functools import singledispatchmethod, wraps from typing import List from comfy.model_patcher import ModelPatcher +from comfy.sd import VAE +from comfy import model_management +from .booster_cache import BoosterCacheService +from .booster_interface import BoosterExecutor, BoosterSettings -from .booster_interface import BoosterExecutor + +def auto_cache_model(func): + @wraps(func) + def wrapper(self: "BoosterScheduler", model=None, *args, **kwargs): + if self.settings is None: + return func(self, model, *args, **kwargs) + cached_model_key = self.settings.tmp_cache_key + cached_model = self.cache_service.get_cached_model(cached_model_key, model) + if cached_model is not None: + return cached_model + cached_model = func(self, model, *args, **kwargs) + self.cache_service.put(cached_model_key, cached_model) + return cached_model + + return wrapper class BoosterScheduler: - def __init__(self, booster_executors: List[BoosterExecutor], * , inplace = True): + def __init__( + self, + booster_executors: List[BoosterExecutor], + *, + inplace=True, + settings: BoosterSettings = None, + ): if not isinstance(booster_executors, (list, tuple)): booster_executors = [booster_executors] self.booster_executors = booster_executors self.inplace = inplace - + self.settings = settings + self.cache_service = BoosterCacheService() def is_empty(self) -> bool: """ Checks if the list of boosters is empty. """ return not self.booster_executors - + + @auto_cache_model def compile(self, model=None, ckpt_name=None, **kwargs): if not self.inplace: model = self.copy(model) for executor in self.booster_executors: - model = executor.execute(model, ckpt_name=ckpt_name, **kwargs) + return model def __call__(self, model=None, ckpt_name=None, **kwargs): return self.compile(model=model, ckpt_name=ckpt_name, **kwargs) - + @singledispatchmethod def copy(self, model): raise NotImplementedError(f"Copying {type(model)} is not implemented.") - + @copy.register def _(self, model: ModelPatcher): + model.model = model.model.to("cpu") new_modelpatcher = model.clone() - new_modelpatcher.model = copy.deepcopy(model.model) + copied_model: nn.Module = copy.deepcopy(model.model) + new_modelpatcher.model = copied_model.to(model_management.get_torch_device()) return new_modelpatcher - - - - + @copy.register + def _(self, model: VAE): + new_vae = copy.deepcopy(model) + return new_vae diff --git a/onediff_comfy_nodes/modules/nexfort/__init__.py b/onediff_comfy_nodes/modules/nexfort/__init__.py new file mode 100644 index 000000000..dbb1deb21 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/__init__.py @@ -0,0 +1,15 @@ +import os +from .hijack_samplers import samplers_hijack +from .hijack_ipadapter_plus import ipadapter_plus_hijacker +from .hijack_pulid_comfyui import pulid_comfyui_hijacker +from .hijack_comfyui_instantid import comfyui_instantid_hijacker + +samplers_hijack.hijack(last=False) +ipadapter_plus_hijacker.hijack(last=False) +pulid_comfyui_hijacker.hijack(last=False) +comfyui_instantid_hijacker.hijack(last=False) + + +# https://github.com/pytorch/pytorch/blob/1edcb31d34ef012d828bb9f39a8aef6020f580b2/aten/src/ATen/cuda/CUDABlas.cpp#L182-L203 +if os.getenv("CUBLASLT_WORKSPACE_SIZE") is None: + os.environ["CUBLASLT_WORKSPACE_SIZE"] = str(1024 * 1024) diff --git a/onediff_comfy_nodes/modules/nexfort/booster_basic.py b/onediff_comfy_nodes/modules/nexfort/booster_basic.py index b4ea3f5b9..35dab257a 100644 --- a/onediff_comfy_nodes/modules/nexfort/booster_basic.py +++ 
b/onediff_comfy_nodes/modules/nexfort/booster_basic.py @@ -1,4 +1,4 @@ -import torch +import torch from functools import partial, singledispatchmethod from typing import Optional @@ -6,9 +6,9 @@ from comfy.model_patcher import ModelPatcher from comfy.sd import VAE -# from onediff.infer_compiler import compile -from nexfort.compilers import nexfort_compile - +from onediff.infer_compiler import compile +from nexfort.utils.memory_format import apply_memory_format +from .onediff_controlnet import OneDiffControlLora from ..booster_interface import BoosterExecutor @@ -16,36 +16,55 @@ class BasicNexFortBoosterExecutor(BoosterExecutor): # https://pytorch.org/docs/stable/_modules/torch.html#compile def __init__( self, + mode: str = "max-optimize:max-autotune:freezing:benchmark:cudagraphs", + fullgraph=False, + dynamic=None, ): super().__init__() - # self.compile_fn = partial(compile, backend="nexfort") - from nexfort.compilers import nexfort_compile + options = { + "mode": mode, + "dynamic": dynamic, + "fullgraph": fullgraph, + } # "memory_format": "channels_last" - mode = "max-optimize:max-autotune:cudagraphs" - self.compile_fn = partial(nexfort_compile, mode=mode, fullgraph=True, dynamic = True) + self.compile_fn = partial(compile, backend="nexfort", options=options) - @singledispatchmethod def execute(self, model, ckpt_name=None, **kwargs): raise NotImplementedError(f"Cannot execute {type(model)=}") @execute.register(ModelPatcher) - @torch.no_grad() + @torch.inference_mode() def _(self, model, ckpt_name: Optional[str] = None, **kwargs): - model.model.diffusion_model = self.compile_fn(model.model.diffusion_model) + diffusion_model = model.model.diffusion_model + model.model.diffusion_model = apply_memory_format( + diffusion_model, torch.channels_last + ) + model.model.diffusion_model = self.compile_fn(diffusion_model) + model.weight_inplace_update = True return model - + @execute.register(VAE) - @torch.no_grad() + @torch.inference_mode() def _(self, model, ckpt_name: Optional[str] = None, **kwargs): - # model.first_stage_model = torch.compile(model.first_stage_model, **self.compile_kwargs) - model.first_stage_model = self.compile_fn(model.first_stage_model) + model.first_stage_model.decode = self.compile_fn(model.first_stage_model.decode) return model @execute.register(ControlNet) - @torch.no_grad() + @torch.inference_mode() def _(self, model, ckpt_name: Optional[str] = None, **kwargs): torch_model = model.control_model - compiled_model = self.compile_fn(torch_model) + torch_model = apply_memory_format(torch_model, torch.channels_last) + compiled_model: torch.nn.Module = self.compile_fn(torch_model) model.control_model = compiled_model return model + + @execute.register(ControlLora) + @torch.inference_mode() + def _(self, model, ckpt_name: Optional[str] = None, **kwargs): + def compile_cnet(model): + out: torch.nn.Module = self.compile_fn(model) + return out + + model = OneDiffControlLora.from_controllora(model, compile_fn=compile_cnet) + return model diff --git a/onediff_comfy_nodes/modules/nexfort/booster_utils.py b/onediff_comfy_nodes/modules/nexfort/booster_utils.py new file mode 100644 index 000000000..db00ad99f --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/booster_utils.py @@ -0,0 +1,20 @@ +from onediff.infer_compiler.backends.nexfort.deployable_module import ( + NexfortDeployableModule as DeployableModule, +) +from comfy.model_patcher import ModelPatcher +from comfy.model_base import BaseModel + + +def clear_deployable_module_cache_and_unbind(*args, **kwargs): + raise 
RuntimeError(f"TODO") + + +def is_using_nexfort_backend(module): + if isinstance(module, ModelPatcher): + if hasattr(module.model, "diffusion_model"): + diff_model = module.model.diffusion_model + return isinstance(diff_model, DeployableModule) + if isinstance(module, BaseModel): + if hasattr(module, "diffusion_model"): + return isinstance(module.diffusion_model, DeployableModule) + return False diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py new file mode 100644 index 000000000..bdfc9a64f --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py @@ -0,0 +1,14 @@ +from ..booster_utils import is_using_nexfort_backend +from ._config import comfyui_instantid_hijacker,comfyui_instantid +from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace + +set_model_patch_replace_fn_pt = comfyui_instantid.InstantID._set_model_patch_replace + + +def cond_func(org_fn, model, *args, **kwargs): + return is_using_nexfort_backend(model) + + +comfyui_instantid_hijacker.register( + set_model_patch_replace_fn_pt, set_model_patch_replace, cond_func +) diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md new file mode 100644 index 000000000..6cc18c45e --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md @@ -0,0 +1,123 @@ +## Accelerating ComfyUI_InstantID with OneDiff +### Environment +Please Refer to the Readme in the Respective Repositories for Installation Instructions. +#### Install OneDiff + +When you have completed these steps, follow the [instructions](https://github.com/siliconflow/onediff/blob/ba93c5a68607abefd38ffed9e6a17bed48c01a81/README.md?plain=1#L224) to install OneDiff. +Then follow the [guide](https://github.com/siliconflow/onediff/blob/0819aa41c8a910add96400265f3165f9d8d3634c/onediff_comfy_nodes/README.md?plain=1#L86) to install ComfyUI OneDiff extension + +#### Install ComfyUI + +``` +cd ComfyUI/custom_nodes +git clone https://github.com/comfyanonymous/ComfyUI +git reset --hard 2d4164271634476627aae31fbec251ca748a0ae0 +``` +When you have completed these steps, follow the [instructions](https://github.com/comfyanonymous/ComfyUI) to install ComfyUI + +#### Install ComfyUI_InstantID + +``` +cd ComfyUI/custom_nodes +git clone https://github.com/cubiq/ComfyUI_InstantID.git +git reset --hard d8c70a0cd8ce0d4d62e78653674320c9c3084ec1 +``` +When you have completed these steps,follow the [instructions](https://github.com/cubiq/ComfyUI_InstantID) to install ComfyUI_InstantID + +### Quick Start + +> Recommend running the official example of ComfyUI_InstantID now, and then trying OneDiff acceleration. +> You can Load these images in ComfyUI to get the full workflow. + +Experiment (GeForce RTX 3090) Workflow for OneDiff Acceleration in ComfyUI_InstantID: + +1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. +2. Add a **`Batch Size Patcher`** node before the **`Ksampler`** node (due to temporary lack of support for dynamic batch size). +As follows: +![workflow (20)](https://github.com/siliconflow/onediff/assets/117806079/492a83a8-1a5b-4fb3-9e53-6d53e881a3f8) + +Note that you can download all images in this page and then drag or load them on ComfyUI to get the workflow embedded in the image. 
+![oneflow_basic](https://github.com/siliconflow/oneflow/assets/117806079/81016bd8-3ec8-457f-850f-9c486bfd2d0c)
+
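+For readers who prefer the module API to the GUI nodes, here is a minimal, hypothetical sketch of what the OneDiff speedup nodes in this PR do under the hood; it assumes `model_patcher` is an already-loaded ComfyUI `ModelPatcher`:
+
+```python
+# Sketch only: mirrors what the ModelSpeedup/OneDiffNexfortBooster nodes in this PR do.
+from onediff_comfy_nodes.modules import BoosterScheduler
+from onediff_comfy_nodes.modules.nexfort.booster_basic import BasicNexFortBoosterExecutor
+
+# One of the suggested compiler-mode combinations listed in this PR.
+executor = BasicNexFortBoosterExecutor(
+    mode="max-autotune:low-precision", fullgraph=False, dynamic=None
+)
+booster = BoosterScheduler(executor, inplace=True)
+model_patcher = booster(model_patcher)  # returns the nexfort-compiled model
+```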
+<details close>
+<summary> Download the required model files </summary>
+
+InstantID requires `insightface`; add it to your environment together with `onnxruntime` and `onnxruntime-gpu`.
+
+The InsightFace model is **antelopev2** (not the classic buffalo_l). Download the models (for example from [here](https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing) or [here](https://huggingface.co/MonsterMMORPG/tools/tree/main)), unzip them, and place them in the `ComfyUI/models/insightface/models/antelopev2` directory.
+
+##### For NA/EU users
+```shell
+cd ComfyUI
+# Load Checkpoint
+wget -O models/checkpoints/sd_xl_base_1.0.safetensors https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors
+
+# Load InstantID Model
+mkdir -p models/instantid
+wget -O models/instantid/ip-adapter.bin https://huggingface.co/InstantX/InstantID/resolve/main/ip-adapter.bin
+
+# Load ControlNet Model
+wget -O models/controlnet/diffusion_pytorch_model.safetensors https://huggingface.co/InstantX/InstantID/resolve/main/ControlNetModel/diffusion_pytorch_model.safetensors
+```
+
+##### For CN users
+```shell
+cd ComfyUI
+# Load Checkpoint
+wget -O models/checkpoints/sd_xl_base_1.0.safetensors https://hf-mirror.com/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors
+
+# Load InstantID Model
+mkdir -p models/instantid
+wget -O models/instantid/ip-adapter.bin https://hf-mirror.com/InstantX/InstantID/resolve/main/ip-adapter.bin
+
+# Load ControlNet Model
+wget -O models/controlnet/diffusion_pytorch_model.safetensors https://hf-mirror.com/InstantX/InstantID/resolve/main/ControlNetModel/diffusion_pytorch_model.safetensors
+```
+
+</details>
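+
+The nexfort hijack modules added in this PR locate custom node packages (ComfyUI_InstantID, ComfyUI_IPAdapter_plus, PuLID_ComfyUI) through the `COMFYUI_ROOT` environment variable (see the `_config.py` files in this PR), so set it before launching ComfyUI. A sketch, with an illustrative path:
+
+```shell
+# The hijack loaders resolve custom_nodes relative to COMFYUI_ROOT
+export COMFYUI_ROOT=/path/to/ComfyUI
+cd ComfyUI && python main.py
+```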
+
+
+### InstantID_basic
+#### Workflow Description
+source: https://github.com/cubiq/ComfyUI_InstantID/blob/main/examples/InstantID_basic.json
+
+| InstantID | Baseline (non-optimized) | OneDiff (optimized) |
+| --------- | ------------------------ | ------------------- |
+| Workflow  | ![InstantID_basic_torch](https://github.com/siliconflow/sd-team/assets/117806079/d649539c-7e8e-449f-b7b5-08622e6f93cc) | ![InstantID_basic_oneflow](https://github.com/siliconflow/sd-team/assets/117806079/c752ca4b-7d81-49b4-915a-9c3088227e9d) |
+
+#### Performance Comparison
+
+Timings for 30 steps at 1024x1024
+
+| Accelerator      | Baseline (non-optimized) | OneDiff (optimized) | Percentage improvement |
+| ---------------- | ------------------------ | ------------------- | ---------------------- |
+| GeForce RTX 3090 | 12.69 s                  | 9 s                 | 29.1 %                 |
+
+### InstantID_IPAdapter
+#### Workflow Description
+source: https://github.com/cubiq/ComfyUI_InstantID/blob/main/examples/InstantID_IPAdapter.json
+
+| InstantID_IPAdapter | Baseline (non-optimized) | OneDiff (optimized) |
+| ------------------- | ------------------------ | ------------------- |
+| Workflow            | ![InstantID_IPAdapter_torch](https://github.com/siliconflow/sd-team/assets/117806079/ba4ba6a9-f9d8-4921-85dd-be00c72f20a6) | ![InstantID_IPAdapter_oneflow](https://github.com/siliconflow/sd-team/assets/117806079/46533f74-7634-4839-8c3e-c555c78eca63) |
+
+#### Performance Comparison
+
+Timings for 30 steps at 1024x1024
+
+| Accelerator      | Baseline (non-optimized) | OneDiff (optimized) | Percentage improvement |
+| ---------------- | ------------------------ | ------------------- | ---------------------- |
+| GeForce RTX 3090 | 13.23 s                  | 9.33 s              | 29.5 %                 |
+
+- **Note:**
+  - Consider setting `ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION=0` and `ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION=0` to ensure computational precision, at the cost of a potential 5% reduction in performance (see the snippet at the end of this README).
+
+## Contact
+
+For users of OneDiff Community, please visit [GitHub Issues](https://github.com/siliconflow/onediff/issues) for bug reports and feature requests.
+
+For users of OneDiff Enterprise, you can contact contact@siliconflow.com for commercial support.
+
+Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates.
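+
+A sketch of applying the precision settings from the note above before launching ComfyUI (the launch command is illustrative):
+
+```shell
+# Trade ~5% performance for stricter computational precision, per the note above
+export ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION=0
+export ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION=0
+cd ComfyUI && python main.py
+```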
\ No newline at end of file diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/__init__.py b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/__init__.py new file mode 100644 index 000000000..80fba76d1 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/__init__.py @@ -0,0 +1,4 @@ +from ._config import comfyui_instantid_hijacker, is_load_comfyui_instantid_pkg + +if is_load_comfyui_instantid_pkg: + from .InstantID import * diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py new file mode 100644 index 000000000..d6decc38b --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py @@ -0,0 +1,22 @@ +import os +import traceback + +COMFYUI_ROOT = os.getenv("COMFYUI_ROOT") +from onediff.utils.import_utils import DynamicModuleLoader +from ...sd_hijack_utils import Hijacker + +__all__ = ["comfyui_instantid"] + +pkg_name = "ComfyUI_InstantID" +pkg_root = os.path.join(COMFYUI_ROOT, "custom_nodes", pkg_name) +is_load_comfyui_instantid_pkg = True +try: + if os.path.exists(pkg_root): + comfyui_instantid = DynamicModuleLoader.from_path(pkg_root) + else: + is_load_comfyui_instantid_pkg = False +except Exception as e: + print(traceback.format_exc()) + print(f"Warning: Failed to load {pkg_root} due to {e}") + is_load_comfyui_instantid_pkg = False +comfyui_instantid_hijacker = Hijacker() diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py new file mode 100644 index 000000000..b3a16a2b8 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py @@ -0,0 +1,256 @@ +import torch +import math +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention + + +def tensor_to_size(source, dest_size): + if isinstance(dest_size, torch.Tensor): + dest_size = dest_size.shape[0] + source_size = source.shape[0] + + if source_size < dest_size: + shape = [dest_size - source_size] + [1] * (source.dim() - 1) + source = torch.cat((source, source[-1:].repeat(shape)), dim=0) + elif source_size > dest_size: + source = source[:dest_size] + + return source + + +class Attn2Replace: + def __init__(self, callback=None, **kwargs): + self.callback = [callback] + self.kwargs = [kwargs] + + self.cache_map = {} # {ui_index, index} + self.forward_patch_key = id(self) + + def add(self, callback, **kwargs): + self.callback.append(callback) + self.kwargs.append(kwargs) + + for key, value in kwargs.items(): + setattr(self, key, value) + + def __deepcopy__(self, memo): + # print("Warning: CrossAttentionPatch is not deepcopiable.", '-'*20) + return self + + def __call__(self, q, k, v, extra_options): + dtype = q.dtype + out = optimized_attention(q, k, v, extra_options["n_heads"]) + # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-predicate + # sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9 + patch_kwargs = extra_options["_attn2"].get(self.forward_patch_key, None) + _sigmas = extra_options["_sigmas"].get(self.forward_patch_key, None) + assert patch_kwargs is not None and isinstance(patch_kwargs, list) + assert _sigmas is not None and isinstance(_sigmas, torch.Tensor) + for i, callback in enumerate(self.callback): + # if sigma <= self.kwargs[i]["sigma_start"] and sigma >= 
self.kwargs[i]["sigma_end"]: + if _sigmas.shape[i] == 1: + out = out + callback( + out, q, k, v, extra_options, **self.kwargs[i], **patch_kwargs[i] + ) + + return out.to(dtype=dtype) + + +def ipadapter_attention( + out, + q, + k, + v, + extra_options, + module_key="", + ipadapter=None, + weight=1.0, + cond=None, + cond_alt=None, + uncond=None, + weight_type="linear", + mask=None, + sigma_start=0.0, + sigma_end=1.0, + unfold_batch=False, + embeds_scaling="V only", + **kwargs +): + dtype = q.dtype + cond_or_uncond = extra_options["cond_or_uncond"] + block_type = extra_options["block"][0] + # block_id = extra_options["block"][1] + t_idx = extra_options["transformer_index"] + layers = 11 if "101_to_k_ip" in ipadapter.ip_layers.to_kvs else 16 + k_key = module_key + "_to_k_ip" + v_key = module_key + "_to_v_ip" + + # extra options for AnimateDiff + ad_params = extra_options["ad_params"] if "ad_params" in extra_options else None + + b = q.shape[0] + seq_len = q.shape[1] + batch_prompt = b // len(cond_or_uncond) + _, _, oh, ow = extra_options["original_shape"] + + if weight_type == "ease in": + weight = weight * (0.05 + 0.95 * (1 - t_idx / layers)) + elif weight_type == "ease out": + weight = weight * (0.05 + 0.95 * (t_idx / layers)) + elif weight_type == "ease in-out": + weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers / 2)) / (layers / 2))) + elif weight_type == "reverse in-out": + weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers / 2)) / (layers / 2))) + elif weight_type == "weak input" and block_type == "input": + weight = weight * 0.2 + elif weight_type == "weak middle" and block_type == "middle": + weight = weight * 0.2 + elif weight_type == "weak output" and block_type == "output": + weight = weight * 0.2 + elif weight_type == "strong middle" and ( + block_type == "input" or block_type == "output" + ): + weight = weight * 0.2 + elif isinstance(weight, dict): + if t_idx not in weight: + return 0 + + weight = weight[t_idx] + + if cond_alt is not None and t_idx in cond_alt: + cond = cond_alt[t_idx] + del cond_alt + + if unfold_batch: + # Check AnimateDiff context window + if ad_params is not None and ad_params["sub_idxs"] is not None: + if isinstance(weight, torch.Tensor): + weight = tensor_to_size(weight, ad_params["full_length"]) + weight = torch.Tensor(weight[ad_params["sub_idxs"]]) + if torch.all(weight == 0): + return 0 + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond + elif weight == 0: + return 0 + + # if image length matches or exceeds full_length get sub_idx images + if cond.shape[0] >= ad_params["full_length"]: + cond = torch.Tensor(cond[ad_params["sub_idxs"]]) + uncond = torch.Tensor(uncond[ad_params["sub_idxs"]]) + # otherwise get sub_idxs images + else: + cond = tensor_to_size(cond, ad_params["full_length"]) + uncond = tensor_to_size(uncond, ad_params["full_length"]) + cond = cond[ad_params["sub_idxs"]] + uncond = uncond[ad_params["sub_idxs"]] + else: + if isinstance(weight, torch.Tensor): + weight = tensor_to_size(weight, batch_prompt) + if torch.all(weight == 0): + return 0 + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond + elif weight == 0: + return 0 + + cond = tensor_to_size(cond, batch_prompt) + uncond = tensor_to_size(uncond, batch_prompt) + + k_cond = ipadapter.ip_layers.to_kvs[k_key](cond) + k_uncond = ipadapter.ip_layers.to_kvs[k_key](uncond) + v_cond = ipadapter.ip_layers.to_kvs[v_key](cond) + v_uncond = ipadapter.ip_layers.to_kvs[v_key](uncond) + else: + # TODO: should we always 
convert the weights to a tensor? + if isinstance(weight, torch.Tensor): + weight = tensor_to_size(weight, batch_prompt) + if torch.all(weight == 0): + return 0 + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond + elif weight == 0: + return 0 + + k_cond = ipadapter.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1) + k_uncond = ipadapter.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1) + v_cond = ipadapter.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1) + v_uncond = ipadapter.ip_layers.to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1) + + ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) + ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) + + if embeds_scaling == "K+mean(V) w/ C penalty": + scaling = float(ip_k.shape[2]) / 1280.0 + weight = weight * scaling + ip_k = ip_k * weight + ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True) + ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight + out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) + del ip_v_mean + elif embeds_scaling == "K+V w/ C penalty": + scaling = float(ip_k.shape[2]) / 1280.0 + weight = weight * scaling + ip_k = ip_k * weight + ip_v = ip_v * weight + out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) + elif embeds_scaling == "K+V": + ip_k = ip_k * weight + ip_v = ip_v * weight + out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) + else: + # ip_v = ip_v * weight + out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) + out_ip = out_ip * weight # I'm doing this to get the same results as before + + if mask is not None: + mask_h = oh / math.sqrt(oh * ow / seq_len) + mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) + mask_w = seq_len // mask_h + + # check if using AnimateDiff and sliding context window + if ( + mask.shape[0] > 1 + and ad_params is not None + and ad_params["sub_idxs"] is not None + ): + # if mask length matches or exceeds full_length, get sub_idx masks + if mask.shape[0] >= ad_params["full_length"]: + mask = torch.Tensor(mask[ad_params["sub_idxs"]]) + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) + else: + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) + mask = tensor_to_size(mask, ad_params["full_length"]) + mask = mask[ad_params["sub_idxs"]] + else: + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) + mask = tensor_to_size(mask, batch_prompt) + + mask = mask.repeat(len(cond_or_uncond), 1, 1) + mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2]) + + # covers cases where extreme aspect ratios can cause the mask to have a wrong size + mask_len = mask_h * mask_w + if mask_len < seq_len: + pad_len = seq_len - mask_len + pad1 = pad_len // 2 + pad2 = pad_len - pad1 + mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0) + elif mask_len > seq_len: + crop_start = (mask_len - seq_len) // 2 + mask = mask[:, crop_start : crop_start + seq_len, :] + + out_ip = out_ip * mask + + # out = out + out_ip + + return out_ip.to(dtype=dtype) diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py new file mode 100644 index 000000000..49468218a --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py @@ -0,0 +1,15 @@ +"""hijack 
ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/IPAdapterPlus.py""" +from ..booster_utils import is_using_nexfort_backend +from ._config import ipadapter_plus_hijacker, ipadapter_plus +from .set_model_patch_replace import set_model_patch_replace + +set_model_patch_replace_fn = ipadapter_plus.IPAdapterPlus.set_model_patch_replace + + +def cond_func(org_fn, model, *args, **kwargs): + return is_using_nexfort_backend(model) + + +ipadapter_plus_hijacker.register( + set_model_patch_replace_fn, set_model_patch_replace, cond_func +) diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/__init__.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/__init__.py new file mode 100644 index 000000000..bb49b8dcc --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/__init__.py @@ -0,0 +1,4 @@ +from ._config import ipadapter_plus_hijacker, is_load_ipadapter_plus_pkg + +if is_load_ipadapter_plus_pkg: + from .IPAdapterPlus import * diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py new file mode 100644 index 000000000..fe1a737ca --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py @@ -0,0 +1,22 @@ +import os +import traceback + +COMFYUI_ROOT = os.getenv("COMFYUI_ROOT") +from onediff.utils.import_utils import DynamicModuleLoader +from ...sd_hijack_utils import Hijacker + +__all__ = ["ipadapter_plus", "ipadapter_plus_hijacker"] + +pkg_name = "ComfyUI_IPAdapter_plus" +pkg_root = os.path.join(COMFYUI_ROOT, "custom_nodes", pkg_name) +is_load_ipadapter_plus_pkg = True +try: + if os.path.exists(pkg_root): + ipadapter_plus = DynamicModuleLoader.from_path(pkg_root) + else: + is_load_ipadapter_plus_pkg = False +except Exception as e: + print(traceback.format_exc()) + print(f"Warning: Failed to load {pkg_root} due to {e}") + is_load_ipadapter_plus_pkg = False +ipadapter_plus_hijacker = Hijacker() diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py new file mode 100644 index 000000000..30e130169 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py @@ -0,0 +1,91 @@ +import torch +from comfy import model_management + +from .CrossAttentionPatch import Attn2Replace, ipadapter_attention +from ..patch_management import create_patch_executor, PatchType +from ..booster_utils import clear_deployable_module_cache_and_unbind + + +def set_model_patch_replace( + org_fn, model, patch_kwargs, key, attention_func=ipadapter_attention +): + diff_model = model.model.diffusion_model + cache_patch_executor = create_patch_executor(PatchType.CachedCrossAttentionPatch) + unet_extra_options_patch_executor = create_patch_executor( + PatchType.UNetExtraInputOptions + ) + cache_dict = cache_patch_executor.get_patch(diff_model) + ui_cache_key = create_patch_executor(PatchType.UiNodeWithIndexPatch).get_patch( + model + ) + unet_extra_options = unet_extra_options_patch_executor.get_patch(diff_model) + + if "attn2" not in unet_extra_options: + unet_extra_options["attn2"] = {} + + to = model.model_options["transformer_options"].copy() + if "patches_replace" not in to: + to["patches_replace"] = {} + else: + to["patches_replace"] = to["patches_replace"].copy() + + if "attn2" not in to["patches_replace"]: + to["patches_replace"]["attn2"] = {} + else: + to["patches_replace"]["attn2"] = 
to["patches_replace"]["attn2"].copy() + + def split_patch_kwargs(patch_kwargs): + split1dict = {} + split2dict = {} + for k, v in patch_kwargs.items(): + if k in ["cond", "cond_alt", "uncond", "mask", "weight"] or isinstance( + v, torch.Tensor + ): + split1dict[k] = v + else: + split2dict[k] = v + + # patch for weight + # weight = split1dict["weight"] + # if isinstance(weight, (int, float)): + # weight = torch.tensor(weight) + # split1dict["weight"] = weight.to(model_management.get_torch_device()) + + return split1dict, split2dict + + new_patch_kwargs, patch_kwargs = split_patch_kwargs(patch_kwargs) + # update patch_kwargs + if key in cache_dict: + try: + attn2_m = cache_dict[key] + index = attn2_m.cache_map.get(ui_cache_key, None) + if index is not None: + unet_extra_options["attn2"][attn2_m.forward_patch_key][ + index + ] = new_patch_kwargs + + to["patches_replace"]["attn2"][key] = attn2_m + model.model_options["transformer_options"] = to + return + + except Exception as e: + clear_deployable_module_cache_and_unbind(model) + + if key not in to["patches_replace"]["attn2"]: + if key not in cache_dict: + attn2_m = Attn2Replace(attention_func, **patch_kwargs) + cache_dict[key] = attn2_m + index = len(attn2_m.callback) - 1 + attn2_m.cache_map[ui_cache_key] = index + unet_extra_options["attn2"][attn2_m.forward_patch_key] = [new_patch_kwargs] + else: + attn2_m = cache_dict[key] + + to["patches_replace"]["attn2"][key] = attn2_m + model.model_options["transformer_options"] = to + else: + attn2_m: Attn2Replace = to["patches_replace"]["attn2"][key] + unet_extra_options["attn2"][attn2_m.forward_patch_key].append( + new_patch_kwargs + ) # update last patch + attn2_m.cache_map[ui_cache_key] = len(attn2_m.callback) - 1 diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py b/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py new file mode 100644 index 000000000..77b10e75c --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py @@ -0,0 +1,23 @@ +from comfy.model_patcher import ModelPatcher + +from ..sd_hijack_utils import Hijacker +from .patch_management import PatchType, create_patch_executor +from .booster_utils import is_using_nexfort_backend + + +def clone_nexfort(org_fn, self, *args, **kwargs): + n = org_fn(self, *args, **kwargs) + create_patch_executor(PatchType.UiNodeWithIndexPatch).copy_to(self, n) + dc_patch_executor = create_patch_executor(PatchType.DCUNetExecutorPatch) + if dc_patch_executor.check_patch(self): + dc_patch_executor.copy_to(self, n) + return n + + +def cond_func(org_fn, self, *args, **kwargs): + return is_using_nexfort_backend(self) + + +model_patch_hijacker = Hijacker() + +model_patch_hijacker.register(ModelPatcher.clone, clone_nexfort, cond_func) diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/__init__.py b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/__init__.py new file mode 100644 index 000000000..8da720289 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/__init__.py @@ -0,0 +1,4 @@ +from ._config import is_load_pulid_comfyui_pkg, pulid_comfyui_hijacker + +if is_load_pulid_comfyui_pkg: + from .pulid import * diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py new file mode 100644 index 000000000..e48ca66e0 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py @@ -0,0 +1,22 @@ +import os +import traceback + +COMFYUI_ROOT = 
os.getenv("COMFYUI_ROOT") +from onediff.utils.import_utils import DynamicModuleLoader +from ...sd_hijack_utils import Hijacker + +__all__ = ["pulid_comfyui", "pulid_comfyui_hijacker"] + +pkg_name = "PuLID_ComfyUI" +pkg_root = os.path.join(COMFYUI_ROOT, "custom_nodes", pkg_name) +is_load_pulid_comfyui_pkg = True +try: + if os.path.exists(pkg_root): + pulid_comfyui = DynamicModuleLoader.from_path(pkg_root) + else: + is_load_pulid_comfyui_pkg = False +except Exception as e: + print(traceback.format_exc()) + print(f"Warning: Failed to load {pkg_root} due to {e}") + is_load_pulid_comfyui_pkg = False +pulid_comfyui_hijacker = Hijacker() diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py new file mode 100644 index 000000000..7dd7e11ca --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py @@ -0,0 +1,23 @@ +from functools import partial +from ._config import pulid_comfyui, pulid_comfyui_hijacker +from ..booster_utils import is_using_nexfort_backend + +from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace + +# ComfyUI/custom_nodes/PuLID_ComfyUI/pulid.py +set_model_patch_replace_fn = pulid_comfyui.pulid.set_model_patch_replace +pulid_attention = pulid_comfyui.pulid.pulid_attention + + +set_model_patch_replace_puild = partial( + set_model_patch_replace, attention_func=pulid_attention +) + + +def cond_func(org_fn, model, *args, **kwargs): + return is_using_nexfort_backend(model) + + +pulid_comfyui_hijacker.register( + set_model_patch_replace_fn, set_model_patch_replace_puild, cond_func +) diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py b/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py new file mode 100644 index 000000000..5a2af0a56 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py @@ -0,0 +1,188 @@ +"""hijack ComfyUI/comfy/samplers.py +commit: 4bd7d55b9028d79829a645edfe8259f7b7a049c0 +Date: Thu Apr 11 22:43:05 2024 -0400 +""" + +from typing import Dict +import torch +from comfy.samplers import calc_cond_batch, can_concat_cond, cond_cat, get_area_and_mult + +from ..sd_hijack_utils import Hijacker +from .patch_management import PatchType, create_patch_executor +from .booster_utils import is_using_nexfort_backend + + +def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): + out_conds = [] + out_counts = [] + to_run = [] + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + if cond is not None: + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, i)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + # to_batch = to_batch_temp[:1] + to_batch = to_batch_temp + # free_memory = model_management.get_free_memory(x_in.device) + # for i in range(1, len(to_batch_temp) + 1): + # batch_amount = to_batch_temp[:len(to_batch_temp)//i] + # input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + # if model.memory_required(input_shape) < free_memory: + # to_batch = batch_amount + # break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + 
input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c["control"] = control.get_control( + input_x, timestep_, c, len(cond_or_uncond) + ) + + transformer_options = {} + if "transformer_options" in model_options: + transformer_options = model_options["transformer_options"].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + + diff_model = model.diffusion_model + transformer_options["sigmas"] = timestep + + if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch( + diff_model + ): + patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions) + extra_options = transformer_options + sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9 + assert "_sigmas" not in extra_options + extra_options["_sigmas"] = {} + attn2_patch_dict = extra_options['patches_replace']["attn2"] + for k, attn_m in attn2_patch_dict.items(): + out_lst = [] + for i, callback in enumerate(attn_m.callback): + if sigma <= attn_m.kwargs[i]["sigma_start"] and sigma >= attn_m.kwargs[i]["sigma_end"]: + out_lst.append(1) + else: + out_lst.append(0) + # extra inputs + transformer_options["_sigmas"][attn_m.forward_patch_key] = torch.randn(*out_lst) + + # extra inputs + transformer_options["_attn2"] = patch_executor.get_patch(diff_model)[ + "attn2" + ] + + c["transformer_options"] = transformer_options + if "model_function_wrapper" in model_options: + output = model_options["model_function_wrapper"]( + model.apply_model, + { + "input": input_x, + "timestep": timestep_, + "c": c, + "cond_or_uncond": cond_or_uncond, + }, + ).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + + # for o in range(batch_chunks): + # cond_index = cond_or_uncond[o] + # out_conds[cond_index][ + # :, + # :, + # area[o][2] : area[o][0] + area[o][2], + # area[o][3] : area[o][1] + area[o][3], + # ] += (output[o] * mult[o]) + # out_counts[cond_index][ + # :, + # :, + # area[o][2] : area[o][0] + area[o][2], + # area[o][3] : area[o][1] + area[o][3], + # ] += mult[o] + + # for i in range(len(out_conds)): + # out_conds[i] /= out_counts[i] + + # return out_conds + + +def cond_func(orig_func, model, *args, **kwargs): + return is_using_nexfort_backend(model) + + +samplers_hijack = Hijacker() +samplers_hijack.register( + orig_func=calc_cond_batch, sub_func=calc_cond_batch_of, 
cond_func=cond_func, +) diff --git a/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py b/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py new file mode 100644 index 000000000..919f9008d --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py @@ -0,0 +1,84 @@ +import inspect +import comfy +from comfy.controlnet import ControlLora, ControlLoraOps, ControlNet + + +class OneDiffControlLora(ControlLora): + @classmethod + def from_controllora(cls, controlnet: ControlLora, *, compile_fn: callable = None): + init_parameters = set(inspect.signature(cls.__init__).parameters.keys()) + init_dict = { + attr: getattr(controlnet, attr) + for attr in init_parameters + if attr != "self" + } + c = cls(**init_dict) + controlnet.copy_to(c) + c._control_model = None + c._compile_fn = compile_fn + return c + + def pre_run(self, model, percent_to_timestep_function): + ControlNet.pre_run(self, model, percent_to_timestep_function) + + self.manual_cast_dtype = model.manual_cast_dtype + dtype = model.get_dtype() + if self.manual_cast_dtype is None: + + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + + else: + + class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): + pass + + dtype = self.manual_cast_dtype + if self._control_model is None: + controlnet_config = model.model_config.unet_config.copy() + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = self.control_weights[ + "input_hint_block.0.weight" + ].shape[1] + controlnet_config["operations"] = control_lora_ops + controlnet_config["dtype"] = dtype + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + self.control_model.to(comfy.model_management.get_torch_device()) + self._control_model = self._compile_fn(self.control_model) + + self.control_model = self._control_model + diffusion_model = model.diffusion_model + sd = diffusion_model.state_dict() + # cm = self.control_model.state_dict() + + for k in sd: + weight = sd[k] + try: + comfy.utils.set_attr_param(self.control_model, k, weight) + except: + pass + + for k in self.control_weights: + if k not in {"lora_controlnet"}: + comfy.utils.set_attr_param( + self.control_model, + k, + self.control_weights[k] + .to(dtype) + .to(comfy.model_management.get_torch_device()), + ) + + def cleanup(self): + pass + + def copy(self): + init_parameters = set(inspect.signature(type(self).__init__).parameters.keys()) + init_dict = { + attr: getattr(self, attr) for attr in init_parameters if attr != "self" + } + c = type(self)(**init_dict) + + self.copy_to(c) + c._control_model = self._control_model + c._compile_fn = self._compile_fn + return c diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py b/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py new file mode 100644 index 000000000..31139aafe --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py @@ -0,0 +1 @@ +from .patch_factory import create_patch_executor, PatchType diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py new file mode 100644 index 000000000..faa2c1ea8 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py @@ -0,0 +1,138 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from comfy.model_patcher import ModelPatcher +from comfy.model_base import BaseModel + + +class PatchExecutorBase(ABC): + @abstractmethod + def 
diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py b/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py
new file mode 100644
index 000000000..31139aafe
--- /dev/null
+++ b/onediff_comfy_nodes/modules/nexfort/patch_management/__init__.py
@@ -0,0 +1 @@
+from .patch_factory import create_patch_executor, PatchType
diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py
new file mode 100644
index 000000000..faa2c1ea8
--- /dev/null
+++ b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from comfy.model_patcher import ModelPatcher
+from comfy.model_base import BaseModel
+
+
+class PatchExecutorBase(ABC):
+    @abstractmethod
+    def check_patch(self):
+        pass
+
+    @abstractmethod
+    def set_patch(self):
+        pass
+
+    @abstractmethod
+    def get_patch(self):
+        pass
+
+
+class UiNodeWithIndexPatch(PatchExecutorBase):
+    DEFAULT_VALUE = -1
+    INCREMENT_VALUE = 1
+
+    def __init__(self) -> None:
+        self.patch_name = type(self).__name__
+
+    def check_patch(self, module: ModelPatcher) -> bool:
+        return hasattr(module, self.patch_name)
+
+    def set_patch(self, module: ModelPatcher, value: int):
+        setattr(module, self.patch_name, value)
+
+    def get_patch(self, module: ModelPatcher) -> int:
+        return getattr(module, self.patch_name, self.DEFAULT_VALUE)
+
+    def copy_to(self, old_model: ModelPatcher, new_model: ModelPatcher):
+        value = self.get_patch(old_model)
+        self.set_patch(new_model, value + self.INCREMENT_VALUE)
+
+
+class CachedCrossAttentionPatch(PatchExecutorBase):
+    def __init__(self) -> None:
+        self.patch_name = type(self).__name__
+
+    def check_patch(self, module):
+        return hasattr(module, self.patch_name)
+
+    def set_patch(self, module, value: dict):
+        setattr(module, self.patch_name, value)
+
+    def get_patch(self, module) -> Dict[str, Any]:
+        if not self.check_patch(module):
+            self.set_patch(module, {})
+        return getattr(module, self.patch_name)
+
+    def clear_patch(self, module):
+        if self.check_patch(module):
+            self.get_patch(module).clear()
+
+
+class CrossAttentionForwardMasksPatch(PatchExecutorBase):
+    def __init__(self) -> None:
+        """Deprecated: scheduled for removal."""
+        self.patch_name = "forward_masks"
+
+    def check_patch(self, module):
+        return hasattr(module, self.patch_name)
+
+    def set_patch(self, module, value):
+        raise NotImplementedError()
+
+    def get_patch(self, module) -> Dict:
+        if not self.check_patch(module):
+            setattr(module, self.patch_name, {})
+        return getattr(module, self.patch_name)
+
+    def clear_patch(self, module):
+        if self.check_patch(module):
+            self.get_patch(module).clear()
+
+
+class DeepCacheUNetExecutorPatch(PatchExecutorBase):
+    def __init__(self) -> None:
+        super().__init__()
+        self.patch_names = ("deep_cache_unet", "fast_deep_cache_unet")
+
+    def check_patch(self, model_patcher: ModelPatcher):
+        return all(hasattr(model_patcher, name) for name in self.patch_names)
+
+    def set_patch(self, model_patcher: ModelPatcher, values):
+        assert len(self.patch_names) == len(values)
+        for attr, value in zip(self.patch_names, values):
+            setattr(model_patcher, attr, value)
+
+    def get_patch(self, model_patcher: ModelPatcher) -> List:
+        return [getattr(model_patcher, attr, None) for attr in self.patch_names]
+
+    def copy_to(self, old_model: ModelPatcher, new_model: ModelPatcher):
+        values = self.get_patch(old_model)
+        self.set_patch(new_model, values)
+        new_model.model.use_deep_cache_unet = True
+
+    def is_use_deep_cache_unet(self, module: BaseModel):
+        return getattr(module, "use_deep_cache_unet", False)
+
+
+class UNetExtraInputOptions(PatchExecutorBase):
+    def __init__(self) -> None:
+        """Bind extra input options (a plain dict) to a UNet module.
+
+        """
+        super().__init__()
+        self.patch_name = type(self).__name__
+
+    def check_patch(self, module):
+        return hasattr(module, self.patch_name)
+
+    def set_patch(self, module, value: Dict):
+        """
+        Bind extra input options to the specified module.
+        For UNet extra input options, the value is a dictionary.
+
+        Args:
+            module: The module object to set the patch attribute on.
+            value (Dict): The extra input options to bind to the module.
+        """
+        setattr(module, self.patch_name, value)
+
+    def get_patch(self, module) -> Dict:
+        if not self.check_patch(module):
+            self.set_patch(module, {})
+        return getattr(module, self.patch_name)
+
+    def clear_patch(self, module):
+        if self.check_patch(module):
+            self.get_patch(module).clear()
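A minimal sketch of the executor API defined above; the Linear module merely stands in for a diffusion model:

import torch
from onediff_comfy_nodes.modules.nexfort.patch_management import (
    PatchType,
    create_patch_executor,
)

module = torch.nn.Linear(4, 4)  # stands in for model.diffusion_model
executor = create_patch_executor(PatchType.UNetExtraInputOptions)

assert not executor.check_patch(module)
options = executor.get_patch(module)  # lazily attaches an empty dict
options["attn2"] = {}
assert executor.check_patch(module)
executor.clear_patch(module)  # empties the dict but keeps the attribute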
+ """ + setattr(module, self.patch_name, value) + + def get_patch(self, module) -> Dict: + if not self.check_patch(module): + self.set_patch(module, {}) + return getattr(module, self.patch_name) + + def clear_patch(self, module): + if self.check_patch(module): + self.get_patch(module).clear() diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py new file mode 100644 index 000000000..5f7e38c33 --- /dev/null +++ b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py @@ -0,0 +1,26 @@ +from enum import Enum +from .patch_executor import ( + CachedCrossAttentionPatch, + DeepCacheUNetExecutorPatch, + UiNodeWithIndexPatch, +) +from .patch_executor import UNetExtraInputOptions + + +class PatchType(Enum): + CachedCrossAttentionPatch = CachedCrossAttentionPatch + UNetExtraInputOptions = UNetExtraInputOptions + DCUNetExecutorPatch = DeepCacheUNetExecutorPatch + UiNodeWithIndexPatch = UiNodeWithIndexPatch + + +def create_patch_executor(selected_patch_type): + for patch_type in PatchType: + if selected_patch_type == patch_type: + return patch_type.value() + raise NotImplementedError(selected_patch_type) + + +if __name__ == "__main__": + patch_executor = create_patch_executor(PatchType.CachedCrossAttentionPatch) + print(patch_executor) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py index 47bfae8b4..cdd2eb57a 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py @@ -4,18 +4,18 @@ """ import torch -from comfy.samplers import (calc_cond_batch, can_concat_cond, cond_cat, - get_area_and_mult) +from comfy.samplers import calc_cond_batch, can_concat_cond, cond_cat, get_area_and_mult from ..sd_hijack_utils import Hijacker from .patch_management import PatchType, create_patch_executor from .utils.booster_utils import is_using_oneflow_backend + def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): out_conds = [] out_counts = [] to_run = [] - + for i in range(len(conds)): out_conds.append(torch.zeros_like(x_in)) out_counts.append(torch.ones_like(x_in) * 1e-37) @@ -72,11 +72,13 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): timestep_ = torch.cat([timestep] * batch_chunks) if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + c["control"] = control.get_control( + input_x, timestep_, c, len(cond_or_uncond) + ) transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() + if "transformer_options" in model_options: + transformer_options = model_options["transformer_options"].copy() if patches is not None: if "patches" in transformer_options: @@ -93,27 +95,48 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): transformer_options["cond_or_uncond"] = cond_or_uncond[:] diff_model = model.diffusion_model - - if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch(diff_model): + + if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch( + diff_model + ): transformer_options["sigmas"] = timestep[0].item() patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions) - transformer_options["_attn2"] = patch_executor.get_patch(diff_model)["attn2"] + transformer_options["_attn2"] = 
patch_executor.get_patch(diff_model)[ + "attn2" + ] else: transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + c["transformer_options"] = transformer_options + if "model_function_wrapper" in model_options: + output = model_options["model_function_wrapper"]( + model.apply_model, + { + "input": input_x, + "timestep": timestep_, + "c": c, + "cond_or_uncond": cond_or_uncond, + }, + ).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - for o in range(batch_chunks): cond_index = cond_or_uncond[o] - out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -126,7 +149,5 @@ def cond_func(orig_func, model, *args, **kwargs): samplers_hijack = Hijacker() samplers_hijack.register( - orig_func=calc_cond_batch, - sub_func=calc_cond_batch_of, - cond_func=cond_func, + orig_func=calc_cond_batch, sub_func=calc_cond_batch_of, cond_func=cond_func, ) diff --git a/onediff_comfy_nodes/modules/sd_hijack_utils.py b/onediff_comfy_nodes/modules/sd_hijack_utils.py index 987bdc861..bf2eb4ebd 100644 --- a/onediff_comfy_nodes/modules/sd_hijack_utils.py +++ b/onediff_comfy_nodes/modules/sd_hijack_utils.py @@ -2,7 +2,8 @@ import importlib import inspect from types import FunctionType -from typing import Callable, Union +from typing import Callable, List, Union +from collections import deque __all__ = ["Hijacker", "hijack_func"] @@ -21,46 +22,101 @@ class CondFunc: Copied from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_hijack_utils.py """ - def __new__(cls, orig_func, sub_func, cond_func): + # Dictionary to store hijacked methods and their corresponding CondFunc instances + hijacked_registry = {} + + def __new__( + cls, + orig_func: Union[str, Callable], + sub_funcs: List[FunctionType], + cond_funcs: List[FunctionType], + ): # self: CondFunc instance self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split(".") - for i in range(len(func_path) - 1, -1, -1): - try: - resolved_obj = importlib.import_module(".".join(func_path[:i])) - break - except ImportError: - pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr( - resolved_obj, - func_path[-1], - lambda *args, **kwargs: self(*args, **kwargs), - ) - - def unhijack_func(): - setattr(resolved_obj, func_path[-1], orig_func) - - self.__init__(orig_func, sub_func, cond_func) - return (lambda *args, **kwargs: self(*args, **kwargs), unhijack_func) - - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - 
self.__cond_func = cond_func + if isinstance(orig_func, FunctionType): + orig_func = get_func_full_name(orig_func) + + assert isinstance(orig_func, str) + + func_path = orig_func.split(".") + for i in range(len(func_path) - 1, -1, -1): + try: + resolved_obj = importlib.import_module(".".join(func_path[:i])) + break + except ImportError: + pass + + if resolved_obj is None: + raise ImportError(f"Could not resolve module for {func_path}") + + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + + orig_func = getattr(resolved_obj, func_path[-1]) + + def hijacked_method(*args, **kwargs): + return self(*args, **kwargs) + + setattr( + resolved_obj, func_path[-1], hijacked_method, + ) + + def unhijack_func(): + setattr(resolved_obj, func_path[-1], orig_func) + del cls.hijacked_registry[hijacked_method] + + self.__init__(orig_func, sub_funcs, cond_funcs) + cls.hijacked_registry[hijacked_method] = self + return (hijacked_method, unhijack_func) + + @staticmethod + def is_hijacked_method(func: Callable): + return func in CondFunc.hijacked_registry + + @staticmethod + def get_hijacked_instance(func: Callable): + return CondFunc.hijacked_registry.get(func) + + def __init__( + self, + orig_func: Callable, + sub_funcs: List[FunctionType], + cond_funcs: List[FunctionType], + ): + self._orig_func = orig_func + self._sub_funcs = deque(sub_funcs) + self._cond_funcs = deque(cond_funcs) def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) + for cond_func, sub_func in zip(self._cond_funcs, self._sub_funcs): + if not cond_func or cond_func(self._orig_func, *args, **kwargs): + return sub_func(self._orig_func, *args, **kwargs) else: - return self.__orig_func(*args, **kwargs) + return self._orig_func(*args, **kwargs) + + def add_condition(self, sub_func: FunctionType, cond_func: FunctionType, last=True): + """Pairs are returned in LIFO order if last is true or FIFO order if false.""" + instance: CondFunc = self + if last: + instance._sub_funcs.append(sub_func) + instance._cond_funcs.append(cond_func) + else: + instance._sub_funcs.appendleft(sub_func) + instance._cond_funcs.appendleft(cond_func) + + +def ensure_list(obj: Union[FunctionType, List[FunctionType]]) -> List[FunctionType]: + if not isinstance(obj, list): + return [obj] + return obj def hijack_func( - orig_func: Union[str, Callable], sub_func: Callable, cond_func: Callable + orig_func: Union[str, Callable], + sub_func: Callable, + cond_func: Callable, + *, + last=True, ): """ Hijacks a function with another function. 
@@ -82,10 +138,14 @@ def hijack_func( >>> foo() bar """ + if CondFunc.is_hijacked_method(orig_func): + ins = CondFunc.get_hijacked_instance(orig_func) + ins.add_condition(sub_func, cond_func, last=last) + return orig_func, lambda: None if isinstance(orig_func, FunctionType): orig_func = get_func_full_name(orig_func) - return CondFunc(orig_func, sub_func, cond_func) + return CondFunc(orig_func, ensure_list(sub_func), ensure_list(cond_func)) class Hijacker: @@ -103,10 +163,10 @@ def __init__(self, funcs_list=[]): self.funcs_list = funcs_list self.unhijack_funcs = [] - def hijack(self): + def hijack(self, last=True): self.unhijack() for orig_func, sub_func, cond_func in self.funcs_list: - _, unhijack_func = hijack_func(orig_func, sub_func, cond_func) + _, unhijack_func = hijack_func(orig_func, sub_func, cond_func, last=last) self.unhijack_funcs.append(unhijack_func) def unhijack(self): @@ -118,5 +178,64 @@ def unhijack(self): def extend_unhijack(self, unhijack_func): self.unhijack_funcs.append(unhijack_func) - def register(self, orig_func, sub_func, cond_func): + def register( + self, orig_func: FunctionType, sub_func: Callable, cond_func: Callable + ): self.funcs_list.append((orig_func, sub_func, cond_func)) + + +if __name__ == "__main__": + + def orig_func(*args, **kwargs): + print("Original function") + return "orig_func" + + def sub_func_0(orig_func, *args, **kwargs): + print(f"Called sub_func_0") + return "sub_func_0" + + def sub_func_1(orig_func, *args, **kwargs): + print(f"Called sub_func_1") + return "sub_func_1" + + cond_0 = True + + def cond_func_0(orig_func, *args, **kwargs): + return cond_0 + + def cond_func_1(orig_func, *args, **kwargs): + return True + + hijack_func(orig_func, sub_func_0, cond_func_0) + assert orig_func() == "sub_func_0" # Output: Called sub_func_0 + + hijack_func(orig_func, sub_func_1, cond_func_1) + cond_0 = False + assert orig_func() == "sub_func_1" # Output: Called sub_func_1 + cond_0 = True + assert orig_func() == "sub_func_0" # Called sub_func_0 + hijack_func(orig_func, sub_func_1, cond_func_1, last=False) + cond_0 = True + assert orig_func() == "sub_func_1" # Called sub_func_1 + + class Case1: + def clone(self): + print(f"{type(self)}.clone") + + def cond_func(org_fn, self): + return True + + def custom_clone(org_fn, self): + print(f"custom_clone") + return "custom_clone" + + hijack_func(Case1.clone, custom_clone, cond_func) + + def custom_clone_1(org_fn, self): + print(f"custom_clone_1") + return "custom_clone_1" + + assert Case1().clone() == "custom_clone" + hijack_func(Case1.clone, custom_clone_1, cond_func, last=False) + + assert Case1().clone() == "custom_clone_1" diff --git a/onediff_comfy_nodes/modules/torch_compile/booster_basic.py b/onediff_comfy_nodes/modules/torch_compile/booster_basic.py index e571b3548..b02bb8101 100644 --- a/onediff_comfy_nodes/modules/torch_compile/booster_basic.py +++ b/onediff_comfy_nodes/modules/torch_compile/booster_basic.py @@ -5,10 +5,18 @@ from comfy.controlnet import ControlLora, ControlNet from comfy.model_patcher import ModelPatcher from comfy.sd import VAE +from onediff.infer_compiler.backends.nexfort.deployable_module import ( + get_deployable_module, +) from ..booster_interface import BoosterExecutor +def compile(model: callable, *args, **kwargs): + compiled_model = torch.compile(model, *args, **kwargs) + return get_deployable_module(model, compiled_model) + + class TorchCompileBoosterExecutor(BoosterExecutor): # https://pytorch.org/docs/stable/_modules/torch.html#compile def __init__( @@ -29,7 +37,7 @@ def 
             "mode": mode,
             "disable": disable,
         }
-        self.compile_fn = partial(torch.compile, **self.compile_kwargs)
+        self.compile_fn = partial(compile, **self.compile_kwargs)
 
     @singledispatchmethod
     def execute(self, model, ckpt_name=None, **kwargs):
diff --git a/onediff_comfy_nodes/utils/__init__.py b/onediff_comfy_nodes/utils/__init__.py
index a15479b11..e69de29bb 100644
--- a/onediff_comfy_nodes/utils/__init__.py
+++ b/onediff_comfy_nodes/utils/__init__.py
@@ -1 +0,0 @@
-from .import_utils import is_nexfort_available, is_oneflow_available
diff --git a/onediff_comfy_nodes/utils/function_selector.py b/onediff_comfy_nodes/utils/function_selector.py
new file mode 100644
index 000000000..39cec2068
--- /dev/null
+++ b/onediff_comfy_nodes/utils/function_selector.py
@@ -0,0 +1,56 @@
+try:
+    import git  # type: ignore
+except ImportError:
+    print("GitPython library not found. Installing...")
+    import subprocess
+
+    subprocess.check_call(["pip", "install", "GitPython"])
+    import git  # type: ignore
+
+from datetime import datetime
+
+
+class FunctionSelectorByCommitDate:
+    def __init__(self, repo_path):
+        self.repo_path = repo_path
+        self.repo = git.Repo(repo_path)
+
+    def _get_current_commit_date(self):
+        """
+        Get the date of the current commit.
+        """
+        current_commit = self.repo.head.commit
+        current_commit_date = datetime.fromtimestamp(current_commit.committed_date)
+        return current_commit, current_commit_date
+
+    def _get_commit_date(self, commit_hash):
+        """
+        Get the date of a specific commit.
+        """
+        commit = self.repo.commit(commit_hash)
+        commit_date = datetime.fromtimestamp(commit.committed_date)
+        return commit_date
+
+    def __call__(self, iterable, default_func=None):
+        """
+        Select a function based on the commit date.
+
+        Args:
+            iterable (iterable): An iterable of tuples, where each tuple contains
+                a commit hash and a corresponding function.
+            default_func (callable, optional): The default function to return
+                if no match is found.
+
+        Returns:
+            callable: The function paired with the latest commit that is still
+                older than the current checkout, or `default_func` if none is.
+        """
+        cur_commit, cur_date = self._get_current_commit_date()
+        sel_date, sel_func = None, default_func
+        for commit_hash, func in iterable:
+            other_commit_date = self._get_commit_date(commit_hash)
+            if cur_date > other_commit_date and (
+                sel_date is None or other_commit_date > sel_date
+            ):
+                sel_date = other_commit_date
+                sel_func = func
+
+        return sel_func
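A hypothetical usage sketch of the selector; the repository path, commit hashes, and loader functions are all placeholders:

from onediff_comfy_nodes.utils.function_selector import FunctionSelectorByCommitDate

def load_legacy():
    ...  # implementation for older checkouts

def load_current():
    ...  # implementation once the API-changing commit is present

selector = FunctionSelectorByCommitDate("/path/to/ComfyUI")
load = selector(
    [("<sha-of-old-api>", load_legacy), ("<sha-of-new-api>", load_current)],
    default_func=load_legacy,
)
load()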
diff --git a/onediff_comfy_nodes/utils/import_utils.py b/onediff_comfy_nodes/utils/import_utils.py
deleted file mode 100644
index d3d6e6f2a..000000000
--- a/onediff_comfy_nodes/utils/import_utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import importlib
-import platform
-
-system = platform.system()
-
-
-def check_module_availability(module_name):
-    spec = importlib.util.find_spec(module_name)
-
-    if spec:
-        try:
-            importlib.import_module(module_name)
-        except ImportError:
-            return False
-    else:
-        return False
-
-    return True
-
-
-_oneflow_available = check_module_availability("oneflow")
-_onediff_quant_available = check_module_availability("onediff_quant")
-_nexfort_available = check_module_availability("nexfort")
-
-if system != "Linux":
-    print(f"Warning: OneFlow is only supported on Linux. Current system: {system}")
-    _oneflow_available = False
-
-
-def is_oneflow_available():
-    return _oneflow_available
-
-
-def is_onediff_quant_available():
-    return _onediff_quant_available
-
-
-def is_nexfort_available():
-    return _nexfort_available
diff --git a/src/onediff/infer_compiler/backends/compiler.py b/src/onediff/infer_compiler/backends/compiler.py
index 4bf91bb83..2793df7c8 100644
--- a/src/onediff/infer_compiler/backends/compiler.py
+++ b/src/onediff/infer_compiler/backends/compiler.py
@@ -1,3 +1,4 @@
+from typing import Callable, Optional
 import torch
 
 from .deployable_module import DeployableModule
@@ -6,7 +7,7 @@
 
 
 def compile(
-    torch_module: torch.nn.Module, *, backend=_DEFAULT_BACKEND, options=None
+    torch_module: Optional[Callable] = None, *, backend=_DEFAULT_BACKEND, options=None
 ) -> DeployableModule:
     from .registry import lookup_backend
diff --git a/src/onediff/infer_compiler/backends/nexfort/deployable_module.py b/src/onediff/infer_compiler/backends/nexfort/deployable_module.py
index a9e94977e..0a513b24c 100644
--- a/src/onediff/infer_compiler/backends/nexfort/deployable_module.py
+++ b/src/onediff/infer_compiler/backends/nexfort/deployable_module.py
@@ -1,17 +1,56 @@
+from types import FunctionType
+from typing import Union
+
 import torch
-
+from torch import nn
+
 from ..deployable_module import DeployableModule
 
 
 class NexfortDeployableModule(DeployableModule):
     def __init__(self, compiled_module, torch_module):
         torch.nn.Module.__init__(self)
-        object.__setattr__(self, "_deployable_module_model", compiled_module)
-        object.__setattr__(self, "_modules", compiled_module._modules)
         object.__setattr__(self, "_torch_module", torch_module)
+        object.__setattr__(self, "_deployable_module_model", compiled_module)
+        # https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L148
+        if isinstance(torch_module, nn.Module) and isinstance(
+            compiled_module, torch._dynamo.eval_frame.OptimizedModule
+        ):
+            object.__setattr__(self, "_modules", compiled_module._orig_mod._modules)
+            object.__setattr__(
+                self, "_parameters", compiled_module._orig_mod._parameters
+            )
+            object.__setattr__(self, "_buffers", compiled_module._orig_mod._buffers)
 
-    def __call__(self, *args, **kwargs):
+    def forward(self, *args, **kwargs):
         return self._deployable_module_model(*args, **kwargs)
 
     def __getattr__(self, name):
         return getattr(self._deployable_module_model, name)
+
+
+def _create_deployable_function(
+    compiled_model, torch_module: FunctionType = None
+) -> FunctionType:
+    # Plain functions need no wrapper; return the compiled callable directly.
+    return compiled_model
+
+
+def _create_mixed_deployable_module(
+    compiled_model, torch_module: nn.Module
+) -> NexfortDeployableModule:
+    module_cls = type(torch_module)
+
+    class MixedNexfortDeployableModule(NexfortDeployableModule, module_cls):
+        def __init__(self, compiled_module, torch_module):
+            super().__init__(compiled_module, torch_module)
+
+        def _get_name(self):
+            return f"{self.__class__.__name__}(of {module_cls.__name__})"
+
+    return MixedNexfortDeployableModule(
+        compiled_module=compiled_model, torch_module=torch_module
+    )
+
+
+def get_deployable_module(
+    torch_module: Union[nn.Module, FunctionType], compiled_model
+) -> Union[NexfortDeployableModule, FunctionType]:
+    if not isinstance(torch_module, nn.Module):
+        return _create_deployable_function(compiled_model, torch_module)
+    return _create_mixed_deployable_module(compiled_model, torch_module)
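A minimal sketch of what get_deployable_module provides; the Linear model is illustrative:

import torch
from onediff.infer_compiler.backends.nexfort.deployable_module import (
    get_deployable_module,
)

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model)
deployable = get_deployable_module(model, compiled)

# The mixed class inherits both NexfortDeployableModule and type(model), so
# downstream isinstance checks and state_dict access keep working.
assert isinstance(deployable, torch.nn.Linear)
y = deployable(torch.randn(2, 8))  # forward() delegates to the compiled module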
diff --git a/src/onediff/infer_compiler/backends/nexfort/nexfort.py b/src/onediff/infer_compiler/backends/nexfort/nexfort.py
index d7f2fa69d..c31fa75c0 100644
--- a/src/onediff/infer_compiler/backends/nexfort/nexfort.py
+++ b/src/onediff/infer_compiler/backends/nexfort/nexfort.py
@@ -1,11 +1,25 @@
-import dataclasses
+from typing import Callable
+
 import torch
+
 from ..registry import register_backend
+from .deployable_module import get_deployable_module
 
 
 @register_backend("nexfort")
 def compile(torch_module: torch.nn.Module, *, options=None):
     from nexfort.compilers import nexfort_compile
+
+    # Decorator mode
+    if torch_module is None:
+
+        def fn(torch_module: Callable):
+            if torch_module is None:
+                raise RuntimeError("torch_module can't be None")
+            return compile(torch_module, options=options)
+
+        return fn
+
     if isinstance(options, str):
         import json
 
@@ -13,5 +27,7 @@ def compile(torch_module: torch.nn.Module, *, options=None):
         options = json.loads(options)
 
     nexfort_options = options if options is not None else dict()
+
     compiled_model = nexfort_compile(torch_module, **nexfort_options)
-    return compiled_model
+
+    return get_deployable_module(torch_module, compiled_model)
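Because compile() now tolerates a missing module, it can be used as a decorator factory. A sketch under the assumption that the top-level onediff.infer_compiler.compile forwards None to this backend; the options keys mirror the test below, and a CUDA device is assumed:

import torch
from onediff.infer_compiler import compile

@compile(backend="nexfort", options={"mode": "max-autotune", "dynamic": True})
def fused_trig(x):
    return torch.sin(x) + torch.cos(x)

y = fused_trig(torch.randn(1024, device="cuda"))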
diff --git a/src/onediff/utils/import_utils.py b/src/onediff/utils/import_utils.py
new file mode 100644
index 000000000..9d7e8b1d4
--- /dev/null
+++ b/src/onediff/utils/import_utils.py
@@ -0,0 +1,71 @@
+import importlib
+import traceback
+from inspect import ismodule
+import os
+import platform
+from types import ModuleType
+
+system = platform.system()
+
+
+def check_module_availability(module_name):
+    spec = importlib.util.find_spec(module_name)
+
+    if spec:
+        try:
+            importlib.import_module(module_name)
+        except ImportError:
+            print(traceback.format_exc())
+            return False
+    else:
+        return False
+
+    return True
+
+
+_oneflow_available = check_module_availability("oneflow")
+_onediff_quant_available = check_module_availability("onediff_quant")
+_nexfort_available = check_module_availability("nexfort")
+
+if system != "Linux":
+    print(f"Warning: OneFlow is only supported on Linux. Current system: {system}")
+    _oneflow_available = False
+
+
+def is_oneflow_available():
+    return _oneflow_available
+
+
+def is_onediff_quant_available():
+    return _onediff_quant_available
+
+
+def is_nexfort_available():
+    return _nexfort_available
+
+
+class DynamicModuleLoader(ModuleType):
+    def __init__(self, obj_entity: ModuleType, pkg_root=None, module_path=None):
+        self.obj_entity = obj_entity
+        self.pkg_root = pkg_root
+        self.module_path = module_path
+
+    @classmethod
+    def from_path(cls, module_path: str):
+        module_name = os.path.basename(module_path)
+        module = importlib.import_module(module_name)
+        return cls(module, module_path, module_path)
+
+    def __getattr__(self, name):
+        obj_entity = getattr(self.obj_entity, name, None)
+        module_path = os.path.join(self.module_path, name)
+        if obj_entity is None:
+            pkg_name = os.path.basename(self.pkg_root)
+            absolute_name = os.path.relpath(module_path, self.pkg_root).replace(
+                os.path.sep, "."
+            )
+            absolute_name = f"{pkg_name}.{absolute_name}"
+            obj_entity = importlib.import_module(absolute_name)
+        if ismodule(obj_entity):
+            return DynamicModuleLoader(obj_entity, self.pkg_root, module_path)
+        return obj_entity
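A hypothetical sketch of DynamicModuleLoader; the path is a placeholder and its basename must already be importable as a top-level package:

from onediff.utils.import_utils import DynamicModuleLoader

# Attribute access first tries the already-imported module, then falls back
# to importing "<pkg>.<attr>" on demand and wraps submodules recursively.
comfyui = DynamicModuleLoader.from_path("/path/to/ComfyUI")
model_management = comfyui.comfy.model_management  # imported lazily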
diff --git a/tests/test_model_inference.py b/tests/test_model_inference.py
new file mode 100644
index 000000000..2996a5b38
--- /dev/null
+++ b/tests/test_model_inference.py
@@ -0,0 +1,101 @@
+import copy
+import time
+import unittest
+from functools import partial
+
+import torch
+from onediff.utils.import_utils import is_oneflow_available, is_nexfort_available
+import onediff.infer_compiler as infer_compiler
+
+
+class SubModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(100, 10)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class MainModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(1000, 100)
+        self.sub_module = SubModule()
+
+    def forward(self, x):
+        x = torch.nn.functional.relu(self.linear(x))
+        x = self.sub_module(x)
+        return x
+
+
+def compute(x):
+    return torch.sin(x) + torch.cos(x)
+
+
+class TestModelInference(unittest.TestCase):
+    def setUp(self) -> None:
+        self.compilation_functions = []
+
+        if is_oneflow_available():
+            oneflow_compile_fn = partial(infer_compiler.compile, backend="oneflow")
+            self.compilation_functions.append(oneflow_compile_fn)
+
+        if is_nexfort_available():
+            nexfort_compile_options = {
+                "mode": "max-optimize:max-autotune:freezing:benchmark:cudagraphs",
+                "dynamic": True,
+                "fullgraph": True,
+            }
+            nexfort_compile_fn = partial(
+                infer_compiler.compile,
+                backend="nexfort",
+                options=nexfort_compile_options,
+            )
+            self.compilation_functions.append(nexfort_compile_fn)
+
+        if not self.compilation_functions:
+            self.skipTest("Neither the oneflow nor the nexfort backend is available")
+
+    def measure_inference_time(
+        self, model, warmup=3, num_runs=30, input_args=(), input_kwargs=None
+    ):
+        input_kwargs = input_kwargs or {}
+        for _ in range(warmup):
+            model(*input_args, **input_kwargs)
+
+        total_time = 0.0
+        for _ in range(num_runs):
+            start_time = time.perf_counter()
+            model(*input_args, **input_kwargs)
+            total_time += time.perf_counter() - start_time
+
+        average_time = total_time / num_runs
+        result = model(*input_args, **input_kwargs)
+        return result, average_time
+
+    def generate_models_and_inputs(self):
+        for compile_fn in self.compilation_functions:
+            # Whole-model compilation.
+            model = MainModule().cuda().half()
+            inputs = [torch.randn(10000, 1000).cuda().half()]
+            compiled_model = compile_fn(model)
+            yield model, compiled_model, inputs, {}
+
+            # Compilation of a single submodule.
+            model = MainModule().cuda().half()
+            inputs = [torch.randn(10000, 1000).cuda().half()]
+            compiled_model_sub = copy.deepcopy(model)
+            compiled_model_sub.sub_module = compile_fn(compiled_model_sub.sub_module)
+            yield model, compiled_model_sub, inputs, {}
+
+            # Compilation of a plain function (nexfort only).
+            if compile_fn.keywords.get("backend") == "nexfort":
+                inputs_compute = [torch.randn(10000, 1000).cuda().half()]
+                compiled_compute_fn = compile_fn(compute)
+                yield compute, compiled_compute_fn, inputs_compute, {}
+
+    @torch.inference_mode()
+    def test_inference_results(self):
+        for model, compiled_model, input_args, input_kwargs in self.generate_models_and_inputs():
+            original_result, _ = self.measure_inference_time(
+                model, input_args=input_args, input_kwargs=input_kwargs
+            )
+            compiled_result, _ = self.measure_inference_time(
+                compiled_model, input_args=input_args, input_kwargs=input_kwargs
+            )
+
+            self.assertTrue(
+                torch.allclose(original_result, compiled_result, atol=1e-2, rtol=1e-3)
+            )
+
+            if isinstance(model, torch.nn.Module):
+                self.assertIsInstance(compiled_model, MainModule)
+                self.assertEqual(
+                    set(model.state_dict().keys()),
+                    set(compiled_model.state_dict().keys()),
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()