Flux VAE broken for float16, force bfloat16 or float32 where compatible (#7213)

## Summary

The Flux VAE, like many VAEs, is broken when run with float16 inputs, returning black images due to NaNs.
This fixes the issue by forcing the VAE to run in bfloat16 or float32 where compatible.
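
The usual cause of this failure is numeric range: float16 tops out around 65504, so large intermediate activations in the VAE overflow to inf and the resulting NaNs render as a black image, while bfloat16 keeps float32's exponent range. A minimal illustration of the range difference (not part of the change itself):

```python
import torch

x = torch.tensor([70000.0])   # larger than float16's maximum of ~65504
print(x.to(torch.float16))    # overflows to inf
print(x.to(torch.bfloat16))   # stays finite: bfloat16 shares float32's exponent range
```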

## Related Issues / Discussions

Fix for issue #7208

## QA Instructions

Tested on macOS: the VAE works with float16 set in invoke.yaml and with the setting left at its default.
I also briefly forced it down the float32 route to check that too.
Needs testing on CUDA / ROCm.

## Merge Plan

It should be a straightforward merge.
RyanJDick authored Nov 13, 2024
2 parents fb19621 + b89caa0 commit ca9cb1c
Showing 4 changed files with 14 additions and 5 deletions.
3 changes: 2 additions & 1 deletion invokeai/app/invocations/flux_vae_decode.py
@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
+            vae_dtype = next(iter(vae.parameters())).dtype
+            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
             img = vae.decode(latents)
 
         img = img.clamp(-1, 1)
5 changes: 2 additions & 3 deletions invokeai/app/invocations/flux_vae_encode.py
@@ -44,9 +44,8 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            image_tensor = image_tensor.to(
-                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
-            )
+            vae_dtype = next(iter(vae.parameters())).dtype
+            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
             latents = vae.encode(image_tensor, sample=True, generator=generator)
             return latents

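Both the encode and decode invocations now follow the same pattern: instead of casting inputs to the globally configured dtype, they read the dtype of the loaded VAE's own parameters and match it. A minimal sketch of the idea, assuming an arbitrary `torch.nn.Module` (the helper name is hypothetical, not part of the change):

```python
import torch


def cast_to_model_dtype(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Whatever dtype the model was loaded in (bfloat16 or float32 here),
    # convert the input tensor to match it before the forward pass.
    model_dtype = next(iter(model.parameters())).dtype
    return x.to(dtype=model_dtype)
```
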
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/load/load_default.py
@@ -35,6 +35,7 @@ def __init__(
         self._logger = logger
         self._ram_cache = ram_cache
         self._torch_dtype = TorchDevice.choose_torch_dtype()
+        self._torch_device = TorchDevice.choose_torch_device()
 
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
10 changes: 9 additions & 1 deletion invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -84,7 +84,15 @@ def _load_model(
             model = AutoEncoder(ae_params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
-            model.to(dtype=self._torch_dtype)
+            # VAE is broken in float16, which mps defaults to
+            if self._torch_dtype == torch.float16:
+                try:
+                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+                except TypeError:
+                    vae_dtype = torch.float32
+            else:
+                vae_dtype = self._torch_dtype
+            model.to(vae_dtype)
 
         return model

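The loader's fallback can be read as a small standalone rule: float16 is never used for the VAE, bfloat16 is preferred when the device accepts it, and float32 is the fallback. A sketch of that logic, with `choose_vae_dtype` as a hypothetical helper name (the actual change inlines this in `_load_model`):

```python
import torch


def choose_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    # Any dtype other than float16 is kept as-is.
    if requested != torch.float16:
        return requested
    try:
        # Allocating a tiny bfloat16 tensor probes whether the device supports it;
        # backends without bfloat16 support raise TypeError.
        return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
    except TypeError:
        return torch.float32
```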
