Speed up SDXL on 16xx series with fp16 weights and manual cast.
comfyanonymous committed Feb 4, 2024
1 parent 98b80ad commit 24129d7
Showing 1 changed file with 3 additions and 3 deletions.
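For context: fp16 compute is known to be broken for these models on NVIDIA 16xx-series cards, so SDXL previously had to be loaded with full fp32 weights there. This commit lets the weights be stored in fp16, halving their memory footprint, while each operation upcasts them to fp32 at compute time. A minimal sketch of that manual-cast pattern, assuming a plain Linear layer; ComfyUI's real casting ops live elsewhere in the repository, and the class name here is illustrative:

import torch
import torch.nn.functional as F

class ManualCastLinear(torch.nn.Linear):
    # Weights stay in fp16 to save VRAM; the math runs in fp32, the only
    # precision that is reliable for these models on 16xx-series GPUs.
    compute_dtype = torch.float32

    def forward(self, x):
        # Upcast the fp16 weight/bias on the fly for each forward pass.
        weight = self.weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
        return F.linear(x.to(self.compute_dtype), weight, bias)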
comfy/model_management.py (3 additions, 3 deletions)
@@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32
 
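The hunk above makes unet_dtype pass manual_cast=True: it is selecting the dtype the weights are stored in, not the dtype the math runs in, so broken fp16 compute is no longer a reason to refuse fp16. A small illustration of that storage/compute split (names and shapes are illustrative, not from the repository):

import torch

storage_dtype = torch.float16  # what unet_dtype can now return on a 16xx card
compute_dtype = torch.float32  # what each op actually runs in after upcasting

w = torch.randn(4, 4, dtype=storage_dtype)   # weights held in VRAM at half size
y = w.to(compute_dtype) @ torch.randn(4, 4)  # the math itself still runs in fp32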

@@ -696,7 +696,7 @@ def is_device_mps(device):
             return True
     return False
 
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
 
     if device is not None:
@@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         if x in props.name.lower():
             fp16_works = True
 
-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
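In the last hunk, model_params * 4 is the byte count of the weights in fp32 (4 bytes per parameter), so fp16 is chosen when performance is deprioritized or the fp32 copy would not fit in free VRAM. Rough arithmetic for the SDXL case in the commit title, with an assumed, approximate parameter count:

model_params = 2_600_000_000        # ~2.6B parameters, assumed for SDXL's UNet
fp32_bytes = model_params * 4       # ~10.4 GB to hold the weights in fp32
fp16_bytes = model_params * 2       # ~5.2 GB with fp16 storage
free_model_memory = 5.5 * 1024**3   # e.g. a 6 GB 16xx card after overhead

print(fp32_bytes > free_model_memory)  # True: fp32 weights alone do not fit
print(fp16_bytes < free_model_memory)  # True: fp16 + manual cast fits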
