
Merge branch 'comfyanonymous:master' into master
KubaBir authored Feb 5, 2024
2 parents eca76f0 + 236bda2 commit 7bc9086
Showing 15 changed files with 233 additions and 43 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -95,24 +95,23 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

Put your VAE in: models/vae

-Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead.

### AMD GPUs (Linux only)
AMD users can install ROCm and PyTorch with pip if they don't already have them installed; this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```

-This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements:
+This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```

### NVIDIA

Nvidia users should install stable pytorch using this command:

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```

-This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements:
+This is the command to install pytorch nightly instead which might have performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
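
As a quick sanity check after installing (CUDA or ROCm), you can confirm that PyTorch detects your GPU:

```python -c "import torch; print(torch.cuda.is_available())"```

This should print `True`; `False` means a CPU-only wheel was installed.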

17 changes: 9 additions & 8 deletions comfy/model_management.py
@@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32

@@ -546,10 +546,8 @@ def text_encoder_dtype(device=None):
     if is_device_cpu(device):
         return torch.float16
 
-    if should_use_fp16(device, prioritize_performance=False):
-        return torch.float16
-    else:
-        return torch.float32
+    return torch.float16
 
 
 def intermediate_device():
     if args.gpu_only:
@@ -698,7 +696,7 @@ def is_device_mps(device):
         return True
     return False
 
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
 
     if device is not None:
@@ -724,10 +722,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if is_intel_xpu():
         return True
 
-    if torch.cuda.is_bf16_supported():
+    if torch.version.hip:
         return True
 
     props = torch.cuda.get_device_properties("cuda")
+    if props.major >= 8:
+        return True
 
     if props.major < 6:
         return False

@@ -740,7 +741,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         if x in props.name.lower():
             fp16_works = True
 
-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
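
Taken together, the `manual_cast` changes shift the question `should_use_fp16` answers: with `manual_cast=True` (as `unet_dtype` now passes), it asks whether fp16 *storage* is viable, not whether native fp16 *compute* is safe. A minimal sketch of the idea; the helper names are hypothetical and the free-memory check from the diff is elided:

```python
import torch

def storage_dtype(native_fp16_ok: bool, manual_cast: bool) -> torch.dtype:
    # Mirrors the new `if fp16_works or manual_cast:` branch: with
    # manual_cast=True, fp16 storage is chosen even on GPUs whose native
    # fp16 kernels are unreliable (the cards this file special-cases).
    if native_fp16_ok or manual_cast:
        return torch.float16
    return torch.float32

def apply_layer(weight: torch.Tensor, x: torch.Tensor, native_fp16_ok: bool) -> torch.Tensor:
    # Manual casting, roughly: weights stay in fp16 to halve memory, but
    # the computation is upcast to fp32 where fp16 compute misbehaves.
    compute = torch.float16 if native_fp16_ok else torch.float32
    return x.to(compute) @ weight.to(compute)
```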
2 changes: 1 addition & 1 deletion comfy/sd.py
@@ -462,7 +462,7 @@ class WeightsLoader(torch.nn.Module):
         model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
-        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
         vae_sd = model_config.process_vae_state_dict(vae_sd)
         vae = VAE(sd=vae_sd)
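
For context, a rough sketch of what a `state_dict_prefix_replace` with `filter_keys=True` plausibly does (an approximation, not ComfyUI's exact implementation): keep only keys matching one of the prefixes and strip the prefix, so VAE weights can be pulled out of a combined checkpoint whatever prefix the model family uses:

```python
def state_dict_prefix_replace(sd, replace_prefix, filter_keys=False):
    # For each (prefix, replacement) pair, rewrite matching keys; with
    # filter_keys=True, keys matching no prefix are dropped entirely.
    out = {}
    for key, value in sd.items():
        for prefix, replacement in replace_prefix.items():
            if key.startswith(prefix):
                out[replacement + key[len(prefix):]] = value
                break
        else:
            if not filter_keys:
                out[key] = value
    return out

# vae_key_prefix defaults to ["first_stage_model."] (see the next file), so
# {k: "" for k in model_config.vae_key_prefix} reproduces the old behavior
# while letting model configs override the prefix.
sd = {"first_stage_model.encoder.w": 0.1, "model.diffusion_model.w": 0.2}
print(state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True))
# {'encoder.w': 0.1}
```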

1 change: 1 addition & 0 deletions comfy/supported_models_base.py
@@ -21,6 +21,7 @@ class BASE:
     noise_aug_config = None
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
+    vae_key_prefix = ["first_stage_model."]
 
     manual_cast_dtype = None

2 changes: 2 additions & 0 deletions comfy/utils.py
@@ -413,6 +413,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
     out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
     for y in range(0, s.shape[2], tile_y - overlap):
         for x in range(0, s.shape[3], tile_x - overlap):
+            x = max(0, min(s.shape[-1] - overlap, x))
+            y = max(0, min(s.shape[-2] - overlap, y))
             s_in = s[:,:,y:y+tile_y,x:x+tile_x]
 
             ps = function(s_in).to(output_device)
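
Why the clamp matters: when the image size is not a multiple of the tile stride, the final tile's origin can land so close to the edge that the tile ends up narrower than `overlap`, which breaks the blending. A worked example with assumed numbers (not from the commit):

```python
# With width 60, tile_x 64, overlap 8, the x loop visits 0 and 56; the tile
# starting at 56 would be only 4 px wide, narrower than the 8 px overlap.
# Clamping each start to at most width - overlap moves it back to 52, so
# every tile spans at least `overlap` pixels.
width, tile_x, overlap = 60, 64, 8
starts = [max(0, min(width - overlap, x)) for x in range(0, width, tile_x - overlap)]
print(starts)  # [0, 52]
```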
2 changes: 1 addition & 1 deletion comfy_extras/nodes_latent.py
@@ -126,7 +126,7 @@ class LatentBatchSeedBehavior:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "samples": ("LATENT",),
-                              "seed_behavior": (["random", "fixed"],),}}
+                              "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "op"
13 changes: 13 additions & 0 deletions custom_nodes/example_node.py.example
@@ -6,6 +6,8 @@ class Example:
     -------------
     INPUT_TYPES (dict):
         Tell the main program input parameters of nodes.
+    IS_CHANGED:
+        Optional method to control when the node is re-executed.
 
     Attributes
     ----------
@@ -89,6 +91,17 @@ class Example:
         image = 1.0 - image
         return (image,)
 
+    """
+    The node will always be re-executed if any of the inputs change, but
+    this method can be used to force the node to execute again even when the inputs don't change.
+    You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
+    executed; if it is different, the node will be executed again.
+    This method is used in the core repo for the LoadImage node, where it returns the image hash as a string; if the image hash
+    changes between executions, the LoadImage node is executed again.
+    """
+    #@classmethod
+    #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
+    #    return ""
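
Filling in the stub above, a sketch of the LoadImage-style pattern the docstring describes: hash the input file and return the digest, so the node re-runs only when the file contents change. Parameter names follow the stub; the choice of SHA-256 here is an assumption:

```python
import hashlib

class Example:
    # ...the rest of the node as defined above...

    @classmethod
    def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
        # Return a fingerprint of the input file; the runtime compares it
        # with the previous run's value and re-executes on any change.
        m = hashlib.sha256()
        with open(image, "rb") as f:
            m.update(f.read())
        return m.hexdigest()
```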

# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
3 changes: 3 additions & 0 deletions web/extensions/core/groupNode.js
@@ -910,6 +910,9 @@ export class GroupNodeHandler {
     const self = this;
     const onNodeCreated = this.node.onNodeCreated;
     this.node.onNodeCreated = function () {
+        if (!this.widgets) {
+            return;
+        }
         const config = self.groupData.nodeData.config;
         if (config) {
             for (const n in config) {
2 changes: 1 addition & 1 deletion web/extensions/core/groupNodeManage.css
@@ -48,7 +48,7 @@
   list-style: none;
 }
 .comfy-group-manage-list-items {
-  max-height: 70vh;
+  max-height: calc(100% - 40px);
   overflow-y: scroll;
   overflow-x: hidden;
 }
(6 more changed files not shown.)
