New device selection
MrReclusive authored Jan 1, 2025
1 parent 561c058 commit 6196fac
Showing 3 changed files with 161 additions and 32 deletions.
111 changes: 90 additions & 21 deletions nodes.py
@@ -132,6 +132,27 @@ def INPUT_TYPES(s):

def setargs(self, **kwargs):
return (kwargs, )

class HyVideoEnhanceAVideo:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
"single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable Enhance-A-Video for single blocks"}),
"double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable Enhance-A-Video for double blocks"}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
},
}
RETURN_TYPES = ("FETAARGS",)
RETURN_NAMES = ("feta_args",)
FUNCTION = "setargs"
CATEGORY = "HunyuanVideoWrapper"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

def setargs(self, **kwargs):
return (kwargs, )

class HyVideoSTG:
@classmethod
@@ -203,6 +224,8 @@ def INPUT_TYPES(s):
"compile_args": ("COMPILEARGS", ),
"block_swap_args": ("BLOCKSWAPARGS", ),
"lora": ("HYVIDLORA", {"default": None}),
"auto_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Enable auto offloading for reduced VRAM usage, implementation from DiffSynth-Studio, slightly different from block swapping and uses even less VRAM, but can be slower as you can't define how much VRAM to use"}),
"cuda_device": ("CUDADEVICE", ),
}
}

@@ -212,9 +235,9 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"

def loadmodel(self, model, base_precision, load_device, quantization,
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None):
compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, auto_cpu_offload=False,cuda_device=None):
transformer = None
mm.unload_all_models()
#mm.unload_all_models()
mm.soft_empty_cache()
manual_offloading = True
if "sage" in attention_mode:
@@ -223,7 +246,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")

device = mm.get_torch_device()
device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()
manual_offloading = True
transformer_load_device = device if load_device == "main_device" else offload_device
@@ -307,7 +330,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,

patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0)

comfy.model_management.load_model_gpu(patcher)
comfy.model_management.load_models_gpu([patcher])
if load_device == "offload_device":
patcher.model.diffusion_model.to(offload_device)

@@ -318,6 +341,9 @@ def loadmodel(self, model, base_precision, load_device, quantization,
from .hyvideo.modules.fp8_optimization import convert_fp8_linear
convert_fp8_linear(patcher.model.diffusion_model, base_dtype)

if auto_cpu_offload:
transformer.enable_auto_offload(dtype=dtype, device=device)

#compile
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
@@ -419,6 +445,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,
patcher.model["manual_offloading"] = manual_offloading
patcher.model["quantization"] = "disabled"
patcher.model["block_swap_args"] = block_swap_args
patcher.model["auto_cpu_offload"] = auto_cpu_offload

return (patcher,)

@@ -430,13 +457,13 @@ def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"}
),
"compile_args":("COMPILEARGS", ),
"cuda_device": ("CUDADEVICE", ),
}
}

@@ -446,8 +473,9 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"
DESCRIPTION = "Loads Hunyuan VAE model from 'ComfyUI/models/vae'"

def loadmodel(self, model_name, device, precision, compile_args=None):
def loadmodel(self, model_name, precision, compile_args=None, cuda_device=None):

device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()

dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
@@ -467,6 +495,7 @@ def loadmodel(self, model_name, device, precision, compile_args=None):
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
vae = torch.compile(vae, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])


return (vae,)

@@ -521,7 +550,6 @@ def INPUT_TYPES(s):
return {
"required": {
"llm_model": (["Kijai/llava-llama-3-8b-text-encoder-tokenizer","xtuner/llava-llama-3-8b-v1_1-transformers"],),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
"clip_model": (["disabled","openai/clip-vit-large-patch14",],),
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"}
@@ -531,6 +559,7 @@ def INPUT_TYPES(s):
"apply_final_norm": ("BOOLEAN", {"default": False}),
"hidden_state_skip_layer": ("INT", {"default": 2}),
"quantization": (['disabled', 'bnb_nf4', "fp8_e4m3fn"], {"default": 'disabled'}),
"cuda_device": ("CUDADEVICE", ),
}
}

@@ -540,12 +569,13 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"
DESCRIPTION = "Loads Hunyuan text_encoder model from 'ComfyUI/models/LLM'"

def loadmodel(self, llm_model, device, clip_model, precision, apply_final_norm=False, hidden_state_skip_layer=2, quantization="disabled"):
def loadmodel(self, llm_model, clip_model, precision, apply_final_norm=False, hidden_state_skip_layer=2, quantization="disabled", cuda_device=None):
lm_type_mapping = {
"Kijai/llava-llama-3-8b-text-encoder-tokenizer": "llm",
"xtuner/llava-llama-3-8b-v1_1-transformers": "vlm",
}
lm_type = lm_type_mapping[llm_model]
device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
quantization_config = None
@@ -660,14 +690,14 @@ def INPUT_TYPES(s):
return {"required": {
"text_encoders": ("HYVIDTEXTENCODER",),
"prompt": ("STRING", {"default": "", "multiline": True} ),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
},
"optional": {
"force_offload": ("BOOLEAN", {"default": True}),
"prompt_template": (["video", "image", "custom", "disabled"], {"default": "video", "tooltip": "Use the default prompt templates for the llm text encoder"}),
"custom_prompt_template": ("PROMPT_TEMPLATE", {"default": PROMPT_TEMPLATE["dit-llm-encode-video"], "multiline": True}),
"clip_l": ("CLIP", {"tooltip": "Use comfy clip model instead, in this case the text encoder loader's clip_l should be disabled"}),
"hyvid_cfg": ("HYVID_CFG", ),
"cuda_device": ("CUDADEVICE", ),
}
}

@@ -676,9 +706,10 @@ def INPUT_TYPES(s):
FUNCTION = "process"
CATEGORY = "HunyuanVideoWrapper"

def process(self, text_encoders, device, prompt, force_offload=True, prompt_template="video", custom_prompt_template=None, clip_l=None, image_token_selection_expr="::4", hyvid_cfg=None, image1=None, image2=None, clip_text_override=None):
def process(self, text_encoders, prompt, force_offload=True, prompt_template="video", custom_prompt_template=None, clip_l=None, image_token_selection_expr="::4", hyvid_cfg=None, image1=None, image2=None, clip_text_override=None, cuda_device=None):
if clip_text_override is not None and len(clip_text_override) == 0:
clip_text_override = None
device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.text_encoder_offload_device()

text_encoder_1 = text_encoders["text_encoder"]
@@ -855,7 +886,6 @@ def INPUT_TYPES(s):
return {"required": {
"text_encoders": ("HYVIDTEXTENCODER",),
"prompt": ("STRING", {"default": "", "multiline": True} ),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
"image_token_selection_expr": ("STRING", {"default": "::4", "multiline": False} ),
},
"optional": {
@@ -1021,17 +1051,18 @@ def INPUT_TYPES(s):
"num_frames": ("INT", {"default": 49, "min": 1, "max": 1024, "step": 4}),
"steps": ("INT", {"default": 30, "min": 1}),
"embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"flow_shift": ("FLOAT", {"default": 9.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"flow_shift": ("FLOAT", {"default": 9.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"force_offload": ("BOOLEAN", {"default": True}),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),

},
"optional": {
"samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"stg_args": ("STGARGS", ),
"context_options": ("COGCONTEXT", ),
"feta_args": ("FETAARGS", ),
"cuda_device": ("CUDADEVICE", ),
}
}

@@ -1040,10 +1071,10 @@ def INPUT_TYPES(s):
FUNCTION = "process"
CATEGORY = "HunyuanVideoWrapper"

def process(self, model, device, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames,
samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None):
def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames,
samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None, cuda_device=None):
model = model.model

device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()
dtype = model["dtype"]
transformer = model["pipe"].transformer
@@ -1095,7 +1126,10 @@ def process(self, model, device, hyvid_embeds, flow_shift, steps, embedded_guida
offload_txt_in = model["block_swap_args"]["offload_txt_in"],
offload_img_in = model["block_swap_args"]["offload_img_in"],
)

elif model["auto_cpu_offload"]:
for name, param in transformer.named_parameters():
if "single" not in name and "double" not in name:
param.data = param.data.to(device)
elif model["manual_offloading"]:
transformer.to(device)

@@ -1129,6 +1163,7 @@ def process(self, model, device, hyvid_embeds, flow_shift, steps, embedded_guida
stg_start_percent=stg_args["stg_start_percent"] if stg_args is not None else 0.0,
stg_end_percent=stg_args["stg_end_percent"] if stg_args is not None else 1.0,
context_options=context_options,
feta_args=feta_args,
)

print_memory(device)
@@ -1158,7 +1193,9 @@ def INPUT_TYPES(s):
"temporal_tiling_sample_size": ("INT", {"default": 64, "min": 4, "max": 256, "tooltip": "Smaller values use less VRAM, model default is 64, any other value will cause stutter"}),
"spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
},
"optional": {
"cuda_device": ("CUDADEVICE", ),
},
}

@@ -1167,7 +1204,8 @@ def INPUT_TYPES(s):
FUNCTION = "decode"
CATEGORY = "HunyuanVideoWrapper"

def decode(self, vae, samples, enable_vae_tiling, temporal_tiling_sample_size, spatial_tile_sample_min_size, auto_tile_size, device):
def decode(self, vae, samples, enable_vae_tiling, temporal_tiling_sample_size, spatial_tile_sample_min_size, auto_tile_size, cuda_device=None):
device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
latents = samples["samples"]
@@ -1242,7 +1280,9 @@ def INPUT_TYPES(s):
"temporal_tiling_sample_size": ("INT", {"default": 64, "min": 4, "max": 256, "tooltip": "Smaller values use less VRAM, model default is 64, any other value will cause stutter"}),
"spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}),
"device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
},
"optional": {
"cuda_device": ("CUDADEVICE", ),
},
}

@@ -1251,7 +1291,8 @@ def INPUT_TYPES(s):
FUNCTION = "encode"
CATEGORY = "HunyuanVideoWrapper"

def encode(self, vae, image, enable_vae_tiling, temporal_tiling_sample_size, auto_tile_size, spatial_tile_sample_min_size, device):
def encode(self, vae, image, enable_vae_tiling, temporal_tiling_sample_size, auto_tile_size, spatial_tile_sample_min_size, cuda_device=None):
device = mm.get_torch_device() if cuda_device is None else cuda_device
offload_device = mm.unet_offload_device()

generator = torch.Generator(device=torch.device("cpu"))#.manual_seed(seed)
@@ -1356,6 +1397,30 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):

return (latent_images.float().cpu(), out_factors)

class HyVideoCudaSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"cuda_device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],),
}
}

RETURN_TYPES = ("CUDADEVICE",)
RETURN_NAMES = ("cuda_device",)
FUNCTION = "select_device"

CATEGORY = "HunyuanVideoWrapper"

def select_device(self, cuda_device):
if not cuda_device:
raise ValueError("No CUDA device selected.")

# Return the selected device
print (cuda_device,)
return (cuda_device,)


NODE_CLASS_MAPPINGS = {
"HyVideoSampler": HyVideoSampler,
"HyVideoDecode": HyVideoDecode,
@@ -1376,6 +1441,8 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"HyVideoTextEmbedsSave": HyVideoTextEmbedsSave,
"HyVideoTextEmbedsLoad": HyVideoTextEmbedsLoad,
"HyVideoContextOptions": HyVideoContextOptions,
"HyVideoEnhanceAVideo": HyVideoEnhanceAVideo,
"HyVideoCudaSelect": HyVideoCudaSelect,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"HyVideoSampler": "HunyuanVideo Sampler",
@@ -1397,4 +1464,6 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"HyVideoTextEmbedsSave": "HunyuanVideo TextEmbeds Save",
"HyVideoTextEmbedsLoad": "HunyuanVideo TextEmbeds Load",
"HyVideoContextOptions": "HunyuanVideo Context Options",
"HyVideoEnhanceAVideo": "HunyuanVideo Enhance A Video",
"HyVideoCudaSelect": "HunyuanVideo Cuda Device Selector",
}
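
Below is a minimal, illustrative sketch (not part of the commit) of the device-selection pattern this change threads through the loader, text encoder, sampler, and VAE nodes: the new HyVideoCudaSelect node emits a device string such as "cuda:1", and each downstream node falls back to the default device only when that optional input is left unconnected, mirroring the repeated `mm.get_torch_device() if cuda_device is None else cuda_device` lines above. The helper name `resolve_device` is hypothetical and used only for illustration.

from typing import Optional

import torch


def resolve_device(cuda_device: Optional[str] = None) -> torch.device:
    # Explicit selection wins; otherwise fall back to a default device
    # (in ComfyUI this fallback would be model_management.get_torch_device()).
    if cuda_device is None:
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.device(cuda_device)


# HyVideoCudaSelect exposes one choice per visible GPU, e.g. "cuda:0", "cuda:1".
choices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
print(resolve_device())  # default device when no selector node is connected
print(resolve_device(choices[-1]) if choices else "no CUDA device available")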
