diff --git a/config/prompts/prompt_travel.json b/config/prompts/prompt_travel.json index 7d792ff8..7bcdf678 100644 --- a/config/prompts/prompt_travel.json +++ b/config/prompts/prompt_travel.json @@ -229,6 +229,15 @@ "control_guidance_end": 1.0, "control_scale_list":[0.5,0.4,0.3,0.2,0.1] }, + "animatediff_controlnet": { + "enable": true, + "use_preprocessor":true, + "guess_mode":false, + "controlnet_conditioning_scale": 1.0, + "control_guidance_start": 0.0, + "control_guidance_end": 1.0, + "control_scale_list":[0.5,0.4,0.3,0.2,0.1] + }, "controlnet_ref": { "enable": false, "ref_image": "ref_image/ref_sample.png", diff --git a/src/animatediff/generate.py b/src/animatediff/generate.py index d9aeff1f..45910e71 100644 --- a/src/animatediff/generate.py +++ b/src/animatediff/generate.py @@ -47,7 +47,8 @@ from animatediff.utils.util import (get_resized_image, get_resized_image2, get_resized_images, get_tensor_interpolation_method, - prepare_dwpose, prepare_ip_adapter, + prepare_dwpose, prepare_extra_controlnet, + prepare_ip_adapter, prepare_ip_adapter_sdxl, prepare_lcm_lora, prepare_lllite, prepare_motion_module, save_frames, save_imgs, save_video) @@ -70,6 +71,7 @@ "qr_code_monster_v1" : ['monster-labs/control_v1p_sd15_qrcode_monster'], "qr_code_monster_v2" : ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'], "controlnet_mediapipe_face" : ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"], + "animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"] } # Edit this table if you want to change to another controlnet checkpoint @@ -284,17 +286,47 @@ def is_valid_controlnet_type(type_str, is_sdxl): else: return (type_str in controlnet_address_table_sdxl) or (type_str in lllite_address_table_sdxl) +def load_controlnet_from_file(file_path, torch_dtype): + from safetensors.torch import load_file + + prepare_extra_controlnet() + file_path = Path(file_path) + if file_path.exists() and file_path.is_file(): + if file_path.suffix.lower() in [".pth", ".pt", ".ckpt"]: + controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True) + elif file_path.suffix.lower() == ".safetensors": + controlnet_state_dict = load_file(file_path, device="cpu") + else: + raise RuntimeError( + f"unknown file format for controlnet weights: {file_path.suffix}" + ) + else: + raise FileNotFoundError(f"no controlnet weights found in {file_path}") + + if file_path.parent.name == "animatediff_controlnet": + model = ControlNetModel(cross_attention_dim=768) + else: + model = ControlNetModel() + + missing, _ = model.load_state_dict(controlnet_state_dict["state_dict"], strict=False) + if len(missing) > 0: + logger.info(f"ControlNetModel has missing keys: {missing}") + + return model.to(dtype=torch_dtype) def create_controlnet_model(pipe, type_str, is_sdxl): if not is_sdxl: if type_str in controlnet_address_table: addr = controlnet_address_table[type_str] - if len(addr) == 1: - return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) + if addr[0] != None: + if len(addr) == 1: + return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16) + else: + return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16) else: - return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16) + return load_controlnet_from_file(addr[1],torch_dtype=torch.float16) else: raise ValueError(f"unknown controlnet type {type_str}") else: diff --git a/src/animatediff/stylize.py b/src/animatediff/stylize.py index 4b9b83a8..7d48eb03 100644 --- a/src/animatediff/stylize.py +++ b/src/animatediff/stylize.py @@ -52,6 +52,7 @@ "qr_code_monster_v1", "qr_code_monster_v2", "controlnet_mediapipe_face", + "animatediff_controlnet", ] def create_controlnet_dir(controlnet_root): diff --git a/src/animatediff/utils/util.py b/src/animatediff/utils/util.py index 6b99d214..a1ca2ef8 100644 --- a/src/animatediff/utils/util.py +++ b/src/animatediff/utils/util.py @@ -347,6 +347,28 @@ def prepare_lllite(): ) +def prepare_extra_controlnet(): + import os + from pathlib import PurePosixPath + + from huggingface_hub import hf_hub_download + + os.makedirs("data/models/controlnet/animatediff_controlnet", exist_ok=True) + for hub_file in [ + "controlnet_checkpoint.ckpt" + ]: + path = Path(hub_file) + + saved_path = "data/models/controlnet/animatediff_controlnet" / path + + if os.path.exists(saved_path): + continue + + hf_hub_download( + repo_id="crishhh/animatediff_controlnet", subfolder=PurePosixPath(path.parent), filename=PurePosixPath(path.name), local_dir="data/models/controlnet/animatediff_controlnet" + ) + + def prepare_motion_module(): import os from pathlib import PurePosixPath