Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flux fp16 inference fix #9097

Merged
merged 10 commits into from
Aug 7, 2024
4 changes: 2 additions & 2 deletions docs/source/en/api/pipelines/flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Both checkpoints have slightly difference usage which we detail below.

```python
import torch
from diffusers import FluxPipeline
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
Expand All @@ -61,7 +61,7 @@ out.save("image.png")

```python
import torch
from diffusers import FluxPipeline
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def forward(
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
latentCall145 marked this conversation as resolved.
Show resolved Hide resolved

return hidden_states

Expand Down Expand Up @@ -223,6 +225,8 @@ def forward(

context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

return encoder_hidden_states, hidden_states

Expand Down
Loading