Flux fp16 inference fix #9097
Conversation
Thanks. Left a comment.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
As this issue mentions, FP16 significantly changes the resulting images. Surprisingly, this has to do with the text encoders (and not the clipping). Specifically, some activations in the text encoders have to be clipped when running in FP16 (it's a dynamic range problem, not a precision one). Forcing FP32 inference on the text encoders thus allows FP16 DiT + VAE inference to produce results similar to FP32/BF16.

Reproduction

```python
from diffusers import FluxPipeline
import matplotlib.pyplot as plt
import torch

torch.backends.cudnn.benchmark = True

DTYPE = torch.float16
ckpt_id = "black-forest-labs/FLUX.1-schnell"

# Load in bf16 first, then cast the whole pipeline to fp16.
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
pipe.to(DTYPE)

images = pipe(
    'A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph',
    num_inference_steps=1,
    num_images_per_prompt=1,
    guidance_scale=0.0,
    height=1024,
    width=1024,
    generator=torch.Generator(device='cuda').manual_seed(0),  # device='cpu' results in different random tensors across different dtypes?
).images

plt.imshow(images[0])
plt.show()
```

Prompt

A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph

Other

num_inference_steps = 1

Outputs (clipped)

(output images omitted)

Outputs (clipped, fp32 text encoders)

(output images omitted)
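For reference, a minimal sketch of the "FP32 text encoders + FP16 DiT/VAE" setup described above (the loading pattern, prompt, and sampler settings here are my assumptions, not code from this PR):

```python
import torch
from diffusers import FluxPipeline

# Load the whole pipeline in fp16, then upcast only the two text encoders to fp32
# so their activations stay within range.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.float16,
).to("cuda")
pipe.text_encoder.to(torch.float32)    # CLIP text encoder
pipe.text_encoder_2.to(torch.float32)  # T5 text encoder

image = pipe(
    "a photo of a corgi",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
```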
Thanks!
Sure. Which part of the investigation would you want in the docs, just the difference between fp16 and bf16 inference and what causes it?
Apologies for not being clear. I think the investigation you presented in #9097 (comment) could be wrapped under a section in the Flux document with the heading "Running FP16 Inference".
As long as it's documented like we're discussing, it should be fine IMO. This way, users have all the information to fix the problems rather than us having to silently fix it for them.
Thanks! Just a single comment.
Ah, we now have a conflict to resolve. Sorry about that.
Thank you!
Great job!
Diffusers has documentation on how to do distributed inference on multiple GPUs, which will probably work for you: https://huggingface.co/docs/diffusers/training/distributed_inference. There's even a section for Flux.1 inference (model sharding), although if you have 32 GB V100s, I don't think you'll need model sharding as long as you enable model CPU offloading, because Flux.1 can fit within 32 GB (though I don't know how offloading behaves with distributed inference).
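For what it's worth, on a single 32 GB card the offloading path could look roughly like this (a sketch; the model ID, step count, and guidance value are assumptions, and fp16 on V100s relies on the clipping this PR adds):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,  # fp16 for V100s, which don't support bf16 natively
)
# Keep each sub-model (text encoders, transformer, VAE) on the CPU and move it
# to the GPU only while it is running, so the full pipeline never has to fit at once.
pipe.enable_model_cpu_offload()

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
```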
In fact, the base Flux.1-dev model does not fit entirely in 32 GB, and if you offload to the CPU, the inference speed drops sharply. The problem is that after pipe.to(dtype) I can't move the model to the GPU. The reverse order doesn't work either.
* clipping for fp16
* fix typo
* added fp16 inference to docs
* fix docs typo
* include link for fp16 investigation

Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
Fixes #9096
Flux can now run inference with torch.half (instead of just torch.bfloat16), allowing faster inference on Turing GPUs. There are two spots where inference with the pretrained weights overflows in fp16, and clipping the activations there produces coherent images.
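As a rough illustration of the kind of guard involved (an illustrative helper, not the PR's actual diff; the exact locations inside the Flux transformer blocks are omitted here):

```python
import torch

def clip_for_fp16(hidden_states: torch.Tensor) -> torch.Tensor:
    # 65504 is the largest finite value representable in float16; clamping the
    # activations to this range keeps them from overflowing to inf/NaN.
    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)
    return hidden_states
```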
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.