
Merge pull request #14597 from AUTOMATIC1111/improved-manual-cast
Improve the implementation of Manual Cast and IPEX support
AUTOMATIC1111 authored Jan 9, 2024
2 parents 6869d95 + ca671e5 commit 905b142
Showing 2 changed files with 41 additions and 16 deletions.
56 changes: 40 additions & 16 deletions modules/devices.py
@@ -110,6 +110,7 @@ def enable_tf32():
 dtype: torch.dtype = torch.float16
 dtype_vae: torch.dtype = torch.float16
 dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False


@@ -131,21 +132,44 @@ def cond_cast_float(input):
 ]
 
 
-def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = torch_utils.get_param(self).dtype
-    self.to(dtype)
-    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-    result = self.org_forward(*args, **kwargs)
-    self.to(org_dtype)
-    return result
+def manual_cast_forward(target_dtype):
+    def forward_wrapper(self, *args, **kwargs):
+        if any(
+            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+            for arg in args
+        ):
+            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+        org_dtype = torch_utils.get_param(self).dtype
+        if org_dtype != target_dtype:
+            self.to(target_dtype)
+        result = self.org_forward(*args, **kwargs)
+        if org_dtype != target_dtype:
+            self.to(org_dtype)
+
+        if target_dtype != dtype_inference:
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
+        return result
+    return forward_wrapper
 
 
 @contextlib.contextmanager
-def manual_cast():
+def manual_cast(target_dtype):
     for module_type in patch_module_list:
         org_forward = module_type.forward
-        module_type.forward = manual_cast_forward
+        if module_type == torch.nn.MultiheadAttention and has_xpu():
+            module_type.forward = manual_cast_forward(torch.float32)
+        else:
+            module_type.forward = manual_cast_forward(torch.float32) if False else manual_cast_forward(target_dtype)
         module_type.org_forward = org_forward
     try:
         yield None
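The key change in this hunk is that manual_cast_forward is now a factory: it takes the dtype the patched modules should run in and returns the wrapper, so different module types can be patched with different dtypes (on IPEX/XPU, torch.nn.MultiheadAttention is always run in float32), and results are cast back to dtype_inference before being handed to the caller. Below is a minimal standalone sketch of the same pattern applied only to torch.nn.Linear; the names make_casting_forward, patched_linear and inference_dtype are illustrative and not part of the repository.

import contextlib
import torch

inference_dtype = torch.float32  # stand-in for devices.dtype_inference


def make_casting_forward(module_cls, target_dtype):
    org_forward = module_cls.forward

    def forward_wrapper(self, *args, **kwargs):
        # cast tensor arguments to the dtype the weights will temporarily use
        args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
        org_dtype = next(self.parameters()).dtype
        if org_dtype != target_dtype:
            self.to(target_dtype)          # temporarily cast weights
        result = org_forward(self, *args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)             # restore the original weight dtype
        # hand the result back in the dtype the rest of the pipeline expects
        if isinstance(result, torch.Tensor) and result.dtype != inference_dtype:
            result = result.to(inference_dtype)
        return result

    return org_forward, forward_wrapper


@contextlib.contextmanager
def patched_linear(target_dtype):
    org_forward, wrapper = make_casting_forward(torch.nn.Linear, target_dtype)
    torch.nn.Linear.forward = wrapper
    try:
        yield
    finally:
        torch.nn.Linear.forward = org_forward  # unpatch


layer = torch.nn.Linear(4, 4)        # float32 weights
x = torch.randn(2, 4)                # float32 input
with patched_linear(torch.bfloat16): # bfloat16 so the sketch also runs on CPU; the webui would use float16 on GPU
    y = layer(x)                     # computed in low precision, returned as float32
print(y.dtype)                       # torch.float32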
@@ -161,15 +185,15 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
-        return manual_cast()
+    if fp8 and dtype_inference == torch.float32:
+        return manual_cast(dtype)
 
-    if has_mps() and shared.cmd_opts.precision != "full":
-        return manual_cast()
-
-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
 
+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype)
+
     return torch.autocast("cuda")
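With this hunk, the choice between torch.autocast, the manual-cast patch, and no casting at all is driven by fp8, dtype, the new dtype_inference, and the detected backend (CUDA without working autocast, XPU, MPS); callers keep treating autocast() as an opaque context manager. A rough sketch of the call pattern, assuming the webui's modules package is importable, shared_init.initialize() has already run, and a default fp16 setup; the toy model stands in for the UNet or text encoder:

import torch
from modules import devices  # the module changed in this commit

# Toy module; in the webui this would already be on devices.device with
# devices.dtype weights.
model = torch.nn.Linear(8, 8).to(devices.device, devices.dtype)
x = torch.ones(1, 8, device=devices.device, dtype=devices.dtype)

# autocast() returns torch.autocast("cuda"), manual_cast(dtype), or a
# nullcontext depending on fp8 / dtype / dtype_inference / backend, so this
# call site does not need to know which casting strategy is in effect.
with devices.autocast():
    y = model(x)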


1 change: 1 addition & 0 deletions modules/shared_init.py
@@ -29,6 +29,7 @@ def initialize():

     devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+    devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
 
     shared.device = devices.device
     shared.weight_load_location = None if cmd_opts.lowram else "cpu"
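The one added line above is what feeds the new devices.dtype_inference field: weights keep using devices.dtype, while --precision full now only forces the dtype results are returned in. A small sketch of the resulting combinations; the helper pick_dtypes exists only for this note and is not part of the repository:

# Illustrative only: how --no-half and --precision full combine after this commit.
def pick_dtypes(no_half: bool, precision_full: bool):
    dtype = "float32" if no_half else "float16"                # weight / storage dtype
    dtype_inference = "float32" if precision_full else dtype   # dtype results are cast to
    return dtype, dtype_inference

print(pick_dtypes(no_half=False, precision_full=False))  # ('float16', 'float16')  default
print(pick_dtypes(no_half=False, precision_full=True))   # ('float16', 'float32')  fp16 weights, fp32 outputs
print(pick_dtypes(no_half=True,  precision_full=False))  # ('float32', 'float32')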
