diff --git a/src/animatediff/generate.py b/src/animatediff/generate.py index a7e223b8..64c04973 100644 --- a/src/animatediff/generate.py +++ b/src/animatediff/generate.py @@ -19,6 +19,7 @@ EulerDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline) +from diffusers.loaders import UNet2DConditionLoadersMixin from PIL import Image from torchvision.datasets.folder import IMG_EXTENSIONS from tqdm.rich import tqdm @@ -71,7 +72,8 @@ "qr_code_monster_v1" : ['monster-labs/control_v1p_sd15_qrcode_monster'], "qr_code_monster_v2" : ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'], "controlnet_mediapipe_face" : ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"], - "animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"] + "animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"], + "controlnet_loose" : ["data/models/controlnet/controlnet_loose/pytorch_lora_weights.safetensors"] } # Edit this table if you want to change to another controlnet checkpoint @@ -118,7 +120,17 @@ onnxruntime_installed = False - +def attach_loaders_mixin(model): + # HACK: make ControlNet work with LoRA by binding every callable attribute of UNet2DConditionLoadersMixin onto the given model instance and returning that same instance (create_controlnet_model then calls load_attn_procs on it). text_encoder_name and unet_name are set first, presumably because the mixin methods read them (TODO confirm). This may not be required in future versions of diffusers. 
+ model.text_encoder_name = "text_encoder" + model.unet_name = "unet" + # mixin_instance = UNet2DConditionLoadersMixin() + for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items(): + # print(attr_name) + if callable(attr_value): + # setattr(model, attr_name, functools.partialmethod(attr_value, model).__get__(model, model.__class__)) + setattr(model, attr_name, partial(attr_value, model)) + return model logger = logging.getLogger(__name__) @@ -322,7 +334,13 @@ def create_controlnet_model(pipe, type_str, is_sdxl): addr = controlnet_address_table[type_str] if addr[0] != None: if len(addr) == 1: - return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) + if type_str == "controlnet_loose": + cn = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16) + cn = attach_loaders_mixin(cn) + cn.load_attn_procs(addr[0]) + return cn + else: + return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) else: return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16) else: