Flux fp16 inference fix #9097

Merged: 10 commits merged into huggingface:main on Aug 7, 2024

Conversation

latentCall145 (Contributor)

What does this PR do?

Fixes #9096
Flux can now run inference with torch.half (instead of just torch.bfloat16), allowing faster inference on Turing GPUs. There are two spots where the pretrained weights overflow in fp16, and clipping the activations at those spots produces coherent images.
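
For illustration, here is a minimal sketch of the kind of clipping this refers to (a hypothetical helper, not the exact change in the diff): clamp fp16 activations to the finite half-precision range so they don't overflow to inf and turn into NaNs downstream.

import torch

def clip_fp16(hidden_states: torch.Tensor) -> torch.Tensor:
    # Clamp to the largest finite fp16 value (65504.0), only when running in half precision.
    if hidden_states.dtype == torch.float16:
        bound = torch.finfo(torch.float16).max
        hidden_states = hidden_states.clamp(-bound, bound)
    return hidden_states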

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul (Member) left a comment

Thanks. Left a comment.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@latentCall145 (Contributor Author) commented Aug 6, 2024

As this issue mentions, FP16 significantly changes the resulting images. Surprisingly, this has to do with the text encoders (and not the clipping). Specifically, some activations in the text encoders have to be clipped when running in FP16 (it's a dynamic-range problem, not a precision one). Forcing FP32 inference for the text encoders therefore lets FP16 DiT + VAE inference produce results similar to FP32/BF16.
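
A minimal sketch of this workaround (illustrative, not the exact recipe that ended up in the docs), assuming the standard FluxPipeline component names:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.float16,  # DiT + VAE run in fp16
)
# Keep the CLIP and T5 text encoders in fp32 to avoid the fp16 dynamic-range overflow.
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
pipe.enable_model_cpu_offload()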

Reproduction

from diffusers import FluxPipeline
import matplotlib.pyplot as plt
import torch

torch.backends.cudnn.benchmark = True

DTYPE = torch.float16  # dtype under test; also run with torch.float32 / torch.bfloat16

ckpt_id = "black-forest-labs/FLUX.1-schnell"
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16,
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
pipe.to(DTYPE)  # cast the whole pipeline to the dtype under test

images = pipe(
    'A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph',
    num_inference_steps=1,
    num_images_per_prompt=1,
    guidance_scale=0.0,
    height=1024,
    width=1024,
    generator=torch.Generator(device='cuda').manual_seed(0), # device='cpu' results in different random tensors across different dtypes?
).images

plt.imshow(images[0])
plt.show()

Prompt

A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph

Other

num_inference_steps = 1
height = width = 1024

Outputs (clipped)

[Image: concatenated outputs, left to right: fp32, bf16, fp16]

Outputs (clipped, fp32 text encoders)

[Image: concatenated outputs with fp32 text encoders, left to right: fp32, bf16, fp16]

@sayakpaul (Member)

Thank you for this investigation. Would you be able to put this analysis in the Flux documentation we have here? I believe this will be extremely valuable to the community. Cc: @DN6 @yiyixuxu

@sayakpaul (Member) left a comment

Thanks!

@latentCall145 (Contributor Author)

> Thank you for this investigation. Would you be able to put this analysis in the Flux documentation we have here? I believe this will be extremely valuable to the community

Sure. Which part of the investigation would you want in docs, just the difference between fp16 + bf16 inference and what causes it?
On a side note, should I also include an option to force fp32 inference for the text encoders when running the Flux pipeline in fp16?

@sayakpaul (Member)

> Sure. Which part of the investigation would you want in docs, just the difference between fp16 + bf16 inference and what causes it?

Apologies for not being clear. I think the investigation you presented in #9097 (comment) could be wrapped under a section in the Flux document with the heading "Running FP16 Inference".

> On a side note, should I also include an option to force fp32 inference for the text encoders when running the Flux pipeline in fp16?

As long as it's documented like we're discussing, it should be fine IMO. This way, users have all the information to fix the problems rather than us having to silently fix it for them.

@sayakpaul (Member) left a comment

Thanks! Just a single comment.

@sayakpaul (Member)

Ah, we now have a conflict to resolve. Sorry about that.

@sayakpaul sayakpaul merged commit 9b5180c into huggingface:main Aug 7, 2024
15 checks passed
@sayakpaul (Member)

Thank you!

@Roman-dem

Great job!
Is it possible to apply such an optimization for inference on two V100 GPUs?

@latentCall145 (Contributor Author) commented Oct 23, 2024

> Great job!
> Is it possible to apply such an optimization for inference on two V100 GPUs?

Diffusers has documentation on how to do distributed inference across multiple GPUs; this will probably work for you: https://huggingface.co/docs/diffusers/training/distributed_inference

There's even a section on Flux.1 inference (model sharding). That said, if you have 32 GB V100s, I don't think you'll need model sharding as long as you enable model CPU offloading, because Flux.1 can fit within 32 GB (although I don't know how offloading behaves with distributed inference).
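
For reference, a minimal single-GPU sketch of model CPU offloading (checkpoint and dtype are illustrative; see the fp16 text-encoder note earlier in this thread):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16
)
pipe.text_encoder.to(torch.float32)   # fp32 text encoders, per the fp16 findings above
pipe.text_encoder_2.to(torch.float32)
# Submodules are moved onto the GPU one at a time, keeping peak VRAM lower
# than loading the whole pipeline onto the device at once.
pipe.enable_model_cpu_offload()
image = pipe("a photo of a forest").images[0]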

@Roman-dem

> Great job!
> Is it possible to apply such an optimization for inference on two V100 GPUs?
>
> Diffusers has documentation on how to do distributed inference across multiple GPUs; this will probably work for you: https://huggingface.co/docs/diffusers/training/distributed_inference
>
> There's even a section on Flux.1 inference (model sharding). That said, if you have 32 GB V100s, I don't think you'll need model sharding as long as you enable model CPU offloading, because Flux.1 can fit within 32 GB (although I don't know how offloading behaves with distributed inference).

In fact, the base Flux.1-dev does not fit entirely in 32 GB, and once the CPU gets involved (offloading), inference speed drops sharply. The problem is that after pipe.to(dtype) I can't move the model to the GPU; the reverse order doesn't work either.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* clipping for fp16

* fix typo

* added fp16 inference to docs

* fix docs typo

* include link for fp16 investigation

---------

Co-authored-by: Sayak Paul <[email protected]>
Successfully merging this pull request may close this issue: Flux inference with torch.half outputs NaN values (#9096)