Skip to content

Commit

Permalink
Add new controlnet(animatediff_controlnet)
Browse files Browse the repository at this point in the history
  • Loading branch information
s9roll7 committed Dec 3, 2023
1 parent 749a2e7 commit 0490a64
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 4 deletions.
9 changes: 9 additions & 0 deletions config/prompts/prompt_travel.json
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@
"control_guidance_end": 1.0,
"control_scale_list":[0.5,0.4,0.3,0.2,0.1]
},
"animatediff_controlnet": {
"enable": true,
"use_preprocessor":true,
"guess_mode":false,
"controlnet_conditioning_scale": 1.0,
"control_guidance_start": 0.0,
"control_guidance_end": 1.0,
"control_scale_list":[0.5,0.4,0.3,0.2,0.1]
},
"controlnet_ref": {
"enable": false,
"ref_image": "ref_image/ref_sample.png",
Expand Down
40 changes: 36 additions & 4 deletions src/animatediff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
from animatediff.utils.util import (get_resized_image, get_resized_image2,
get_resized_images,
get_tensor_interpolation_method,
prepare_dwpose, prepare_ip_adapter,
prepare_dwpose, prepare_extra_controlnet,
prepare_ip_adapter,
prepare_ip_adapter_sdxl, prepare_lcm_lora,
prepare_lllite, prepare_motion_module,
save_frames, save_imgs, save_video)
Expand All @@ -70,6 +71,7 @@
"qr_code_monster_v1" : ['monster-labs/control_v1p_sd15_qrcode_monster'],
"qr_code_monster_v2" : ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'],
"controlnet_mediapipe_face" : ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"],
"animatediff_controlnet" : [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"]
}

# Edit this table if you want to change to another controlnet checkpoint
Expand Down Expand Up @@ -284,17 +286,47 @@ def is_valid_controlnet_type(type_str, is_sdxl):
else:
return (type_str in controlnet_address_table_sdxl) or (type_str in lllite_address_table_sdxl)

def load_controlnet_from_file(file_path, torch_dtype):
from safetensors.torch import load_file

prepare_extra_controlnet()

file_path = Path(file_path)

if file_path.exists() and file_path.is_file():
if file_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
elif file_path.suffix.lower() == ".safetensors":
controlnet_state_dict = load_file(file_path, device="cpu")
else:
raise RuntimeError(
f"unknown file format for controlnet weights: {file_path.suffix}"
)
else:
raise FileNotFoundError(f"no controlnet weights found in {file_path}")

if file_path.parent.name == "animatediff_controlnet":
model = ControlNetModel(cross_attention_dim=768)
else:
model = ControlNetModel()

missing, _ = model.load_state_dict(controlnet_state_dict["state_dict"], strict=False)
if len(missing) > 0:
logger.info(f"ControlNetModel has missing keys: {missing}")

return model.to(dtype=torch_dtype)

def create_controlnet_model(pipe, type_str, is_sdxl):
if not is_sdxl:
if type_str in controlnet_address_table:
addr = controlnet_address_table[type_str]
if len(addr) == 1:
return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
if addr[0] != None:
if len(addr) == 1:
return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
else:
return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
else:
return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
return load_controlnet_from_file(addr[1],torch_dtype=torch.float16)
else:
raise ValueError(f"unknown controlnet type {type_str}")
else:
Expand Down
1 change: 1 addition & 0 deletions src/animatediff/stylize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"qr_code_monster_v1",
"qr_code_monster_v2",
"controlnet_mediapipe_face",
"animatediff_controlnet",
]

def create_controlnet_dir(controlnet_root):
Expand Down
22 changes: 22 additions & 0 deletions src/animatediff/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,28 @@ def prepare_lllite():
)


def prepare_extra_controlnet():
import os
from pathlib import PurePosixPath

from huggingface_hub import hf_hub_download

os.makedirs("data/models/controlnet/animatediff_controlnet", exist_ok=True)
for hub_file in [
"controlnet_checkpoint.ckpt"
]:
path = Path(hub_file)

saved_path = "data/models/controlnet/animatediff_controlnet" / path

if os.path.exists(saved_path):
continue

hf_hub_download(
repo_id="crishhh/animatediff_controlnet", subfolder=PurePosixPath(path.parent), filename=PurePosixPath(path.name), local_dir="data/models/controlnet/animatediff_controlnet"
)


def prepare_motion_module():
import os
from pathlib import PurePosixPath
Expand Down

2 comments on commit 0490a64

@amirothman
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this controlnet do? Can't find any info on this online 😅

@s9roll7
Copy link
Owner Author

@s9roll7 s9roll7 commented on 0490a64 Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know all the details, but it seems to work like ip2p.
https://github.com/s9roll7/animatediff-cli-prompt-travel#example
#189

Please sign in to comment.