Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A big improvement for dtype casting system with fp8 storage type and manual cast #14031

Merged
merged 33 commits into from
Dec 16, 2023

Conversation

KohakuBlueleaf
Copy link
Collaborator

@KohakuBlueleaf KohakuBlueleaf commented Nov 19, 2023

Description

After pytorch 2.1.0, pytorch added 2 new dtype as storage type: float8_e5m2, float8_e4m3fn.[1][2]
Based on the papers which discuss the usage of fp8 as parameter/gradient for training/using NN models. I think it is worth doing some optimization with fp8 format.[3][4]
Also, some extension already support this feature too [5]

Mechanism

Although pytorch2.1.0 start supporting fp8 as storage type. We have only few hidden method for H100 to computing matmul with fp8 dtype.[6] Which means even though we can store model weights in FP8, we still need to use fp16 or bf16 to compute the result. (a.k.a upcasting)

Fortunately, pytorch's autocast can do it for us without any other changes. We just need to avoid some modules which not support fp8 storage, for example: nn.Embedding.
And for doing this for some devices which not support autocast, I also implement a manualcast hook which support GTX16xx(or even older) series to utilize fp8 features.

Manual Cast

The idea is pretty simple, when parameters and inputs have different dtype then target dtype (defined in devices.py), cast it to target dtype.

def manual_cast_forward(self, *args, **kwargs):
    org_dtype = next(self.parameters()).dtype
    self.to(dtype)
    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    result = self.org_forward(*args, **kwargs)
    self.to(org_dtype)
    return result

I hooked the Linear, Conv, MultiHeadAttention, GroupNorm, LayerNorm layers. Which could almost covered all the layers we need. But since the hook is very simple, we can just add the layers that need to be patched in the future.

This feature also allow GTX16xx series or older to utilize fp16 (even fp8) weight to reduce vram usage significantly.

List of implemented things

  • FP8 cast for UNet and TE. (Linear/Conv layers only)
  • Manual Cast for GTX16xx/GTX10xx/MPS/CPU... (MPS not support FP8 storage yet)
  • XYZ-grid support for FP8 mode.
  • [For Dev] Forced reload option in reload_model_weights. (in FP8, I use this feature to apply fp8 settings correctly)
  • Built-in LoRA system change some dtype convert method so we can apply lora to a fp8 weight models.

Performance on VRAM usage/Speed/Quality

Vram usage

Theoretically, FP8 can reduce 50% of "vram usage introduced by weights" (FP8 vs FP16). So for sd1.x with --medvram, it can save about 800MB vram when inference, or 2.5G vram for sdxl.

Here is some measurement with HWinfo64 on Windows 11 23H2/RTX 4090
Launch args: --medvram --opt-channelslast --xformers

1024x1024 SDXL 768x768 SD1.x 512x512 SD1.x
FP16 5923MB 2085MB 2128MB
FP8 3249MB 1198MB 935MB
image image image

Table 1. Vram usage comparison between FP16/FP8

First spike is FP16 run and Second spike is FP8

Upper bound: FP16 vram usage
Lower bound: idle vram usage
Cursor box: FP8 vram usage

We can find that FP8 save 2.5 ~ 2.8G vram in SDXL and 0.8 ~ 0.9G vram in SD1.x.
Which is almost match the theory.

Speed

Since use FP8 with FP16 computation need some extra operation to cast the dtype.
It will decrease the speed (especially for smaller batch size)

Batch size 768x768 SD1.x fp16 768x768 SD1.x fp8 1024x1024 SDXL fp16 1024x1024 SDXL fp8
1 8.27it/s 7.85it/s 3.84it/s 3.67it/s
4 3.19it/s 3.08it/s 1.51it/s 1.45it/s

Table 2. Inference speed comparison between FP16/FP8

We can find that although FP8 is slower than FP16 but the speed difference is smaller than 5%, which is acceptable.

Quality

Surprisingly, FP8 barely decrease the quality of generated image, sometime it may even improve the quality. But normally the differences are subtle.

Some comparisons here:
xyz_grid-1214-2023-11-19_51a0c178b7_kohaku-xl-gamma1_2_8be10275-1760x2165

Image 1. Image comparison between FP16/FP8 on SDXL

xyz_grid-1219-2023-11-19_7f3da1fab9_kohaku-v4-rev1 2_2_84c4546d-1584x2027

Image 2. Image comparison between FP16/FP8 on SD1.x

But interestingly(or, actually expceted), use LoRA/LyCORIS with FP8 enabled will "weaken" the effect of the lora, which means you may need higher (1.5x ~ 3x) weight to reach same effect.
For example here is the comparison on LCM-lora + SD1.x models. with Euler A/cfg 2/steps8:
xyz_grid-1224-2023-11-19_7f3da1fab9_kohaku-v4-rev1 2_2_22636977-2736x2026

Image 3. LoRA effect on FP8

some reported improvement

Information from animatediff extension, use FP8 on UNet+CN+AN can achive 1024x1024 i2i on 4090. Which is impossible in the past.
Also reported from users of animatediff, use fp8+lcm lora can improve the quality of result.

Conclusion

FP8 is good and almost zero cost improvement on VRAM usage. Which is a good news for SDXL users.
Moreover, we can even use fp8 to storage the model weight directly which can reduce the ram requirement as well. (fp8 ckpt + fp8 storage enabled can achive sys ram 8G + vram 4G requirement with --medvram)

But the "weaken effect" on LoRA also reveal some possible problem after apply this method. We may need more feedback on this feature.But Good news for this problem: just disable it can avoid every affection.

And this feature also required pytorch2.1.0 which "may" be unstable. (since pytorch 2.1.1 already be stable, maybe we can consider to wait for xformers' update for pytorch 2.1.1 and then merge this PR)

Reference

  1. RFC-0030: FP8 dtype introduction to PyTorch pytorch/rfcs#51
  2. [RFC] FP8 dtype introduction to PyTorch pytorch/pytorch#91577
  3. https://arxiv.org/abs/1812.08011
  4. https://arxiv.org/abs/2209.05433
  5. https://github.com/continue-revolution/sd-webui-animatediff#optimizations
  6. https://discuss.pytorch.org/t/fp8-support-on-h100/190236
  7. https://civitai.com/models/179309/lyco-exp-venus-park-umamusume

Appandix

Training with FP8 also be implemented based on kohya-ss/sd-scripts codebase in my fork, I also provide some example models for it. With my implementation, users can train sdxl lora/lycoris on 6G vram card with TE/latent been cached, or train them on 8G vram card with nothing been cached with 1024x1024 arb.[7]

Checklist:

@KohakuBlueleaf KohakuBlueleaf changed the title A big imporvement for dtype casting system with fp8 storage type and manual cast A big improvement for dtype casting system with fp8 storage type and manual cast Nov 19, 2023
modules/launch_utils.py Outdated Show resolved Hide resolved
@werran2
Copy link

werran2 commented Nov 20, 2023

great work

@BetaDoggo
Copy link

Could the Lora issue be solved by merging the lora weights before converting the model to FP8 for inference? It would require the model to be reloaded every time a lora is changed but if the fp16 version is cached in ram I think it could still be fast enough to be worth it.

@KohakuBlueleaf
Copy link
Collaborator Author

Could the Lora issue be solved by merging the lora weights before converting the model to FP8 for inference? It would require the model to be reloaded every time a lora is changed but if the fp16 version is cached in ram I think it could still be fast enough to be worth it.

Yes it could be solved
I can add this kind of options

@KohakuBlueleaf
Copy link
Collaborator Author

Could the Lora issue be solved by merging the lora weights before converting the model to FP8 for inference? It would require the model to be reloaded every time a lora is changed but if the fp16 version is cached in ram I think it could still be fast enough to be worth it.

Sorry I'm wrong, it is quite hard to be solved since lora ext actually don't know when it should load the fp16 weights.

I have an idea is to cache the fp16 weight in fp8-layers directly (in CPU) but it will require more sys ram.
But it will be more easy to done it.

Don't know if you think it is ok. (and definitely, it will be an option)

@KohakuBlueleaf
Copy link
Collaborator Author

@BetaDoggo I have added a mechanism to cache/restore fp16 weight when needed. This will require 5G more system ram to achive it on SDXL. But it do give us closer result with fp8 when we using lora:
xyz_grid-1251-2023-11-21_7f3da1fab9_kohaku-v4-rev1 2_2_eac61c51-2736x2026

@nosferatu500
Copy link

M2 Pro 16gb RAM

Same prompt + sampler + seed + plugins (too tired to describe everything).

Main branch: Using 16gb + 10gb swap
test-fp8 branch: using 13gb ram and 0mb swap

Bravo!

@saunderez
Copy link

Just wanted to give some feedback that I've been using your branch for a couple of weeks now and have had no problems at all (CUDA 4080). I'm finding minimal quality loss versus BF16 and generation speeds pretty much on par with LCM/Turbo models and LORAs at 16bit.

@ClashSAN
Copy link
Collaborator

Hi @KohakuBlueleaf,

Is fp8 going to be in v1.7.0?
If so, could you add a "fp8" or "experimental fp8" to metadata? it looks like the SD1.5 difference is much larger than --xformers or --upcast-sampling variations.

I also tried the fp8 settings option, and I get a non-reproducible image during the first switch from fp16 to fp8, when applying changes in settings:

stable-diffusion-v1-5 fp8 non-reproducible image
sd-1-5-fp16 fp8 glitch

@KohakuBlueleaf
Copy link
Collaborator Author

KohakuBlueleaf commented Dec 15, 2023

Hi @KohakuBlueleaf,

Is fp8 going to be in v1.7.0? If so, could you add a "fp8" or "experimental fp8" to metadata? it looks like the SD1.5 difference is much larger than --xformers or --upcast-sampling variations.

I also tried the fp8 settings option, and I get a non-reproducible image during the first switch from fp16 to fp8, when applying changes in settings:

stable-diffusion-v1-5 fp8 non-reproducible image
sd-1-5-fp16 fp8 glitch

  1. I don't think fp8 will be merged into 1.7.0 but this is depends on Automatic
  2. I'm not very sure what point is non-reproducible. Do you mean "when I first try fp8, it give me a totally different image but I cannot reproduce it" or "when I use fp8 branch, it give me a totally different image then before"
  3. Thx for noticing me to add FP8 related settings into infotext

@ClashSAN
Copy link
Collaborator

ClashSAN commented Dec 15, 2023

vid.mp4

when I have fp8 enabled in settings, then exit the program. Then open the program, and the random seed distribution is entirely different. GPU seed is affected, not CPU seed.

@KohakuBlueleaf
Copy link
Collaborator Author

vid.mp4
when I have fp8 enabled in settings, then exit the program. Then open the program, and the random seed distribution is entirely different. GPU seed is affected, not CPU seed.

This is quite interesting, will check it

@KohakuBlueleaf
Copy link
Collaborator Author

vid.mp4
when I have fp8 enabled in settings, then exit the program. Then open the program, and the random seed distribution is entirely different. GPU seed is affected, not CPU seed.

Want to check if I understand this correctly:
"When startup the program with 'fp8' enabled, it will generate non-reproducible strange image"

But if the "fp8" is enabled "after" startup (startup with fp16), it will be normal.

@AUTOMATIC1111 AUTOMATIC1111 merged commit c121f8c into dev Dec 16, 2023
6 checks passed
@AUTOMATIC1111 AUTOMATIC1111 deleted the test-fp8 branch December 16, 2023 07:22
@KohakuBlueleaf
Copy link
Collaborator Author

@ClashSAN I have investigated some similar effect but different.
I tried some debug log and all of them looks normal for me...
I will put this into dev server for more help.

@AUTOMATIC1111
Copy link
Owner

AUTOMATIC1111 commented Dec 16, 2023

I have a similar effect:

  1. generate pic:
    00052-1

  2. enable fp8 and generate pic:
    00054-1

  3. disable fp8 and generate pic:
    00057-1

My suspicion is that it has to do with cond caching.

@KohakuBlueleaf
Copy link
Collaborator Author

cond caching.
Makes sense to me.

@pkuliyi2015
Copy link

My suspicion is that it has to do with cond caching.

I confirmed your conjesture by breakpointing at the UNetModel's forward and do a switch of fp16->fp8->fp16. The result shows that context cache is permanently changed after enabling fp8. So the problem can be effectively fixed by invalidating the cond cache at the switching time.

Here is the crucial evidence:

db2a8c565af3ff77468d1194e697a3ff

@Manchovies
Copy link

I keep getting this error:
Traceback (most recent call last): File "C:\Users\sonic\AppData\Local\Programs\Python\Python310\lib\threading.py", line 973, in _bootstrap self._bootstrap_inner() File "C:\Users\sonic\AppData\Local\Programs\Python\Python310\lib\threading.py", line 1016, in _bootstrap_inner self.run() File "C:\Users\sonic\AppData\Local\Programs\Python\Python310\lib\threading.py", line 953, in run self._target(*self._args, **self._kwargs) File "D:\stable-diffusion-webui-1.5.1 (1)\newstablediffusionwebui\stable-diffusion-webui\modules\initialize.py", line 147, in load_model shared.sd_model # noqa: B018 File "D:\stable-diffusion-webui-1.5.1 (1)\newstablediffusionwebui\stable-diffusion-webui\modules\shared_items.py", line 110, in sd_model return modules.sd_models.model_data.get_sd_model() File "D:\stable-diffusion-webui-1.5.1 (1)\newstablediffusionwebui\stable-diffusion-webui\modules\sd_models.py", line 522, in get_sd_model load_model() File "D:\stable-diffusion-webui-1.5.1 (1)\newstablediffusionwebui\stable-diffusion-webui\modules\sd_models.py", line 649, in load_model load_model_weights(sd_model, checkpoint_info, state_dict, timer) File "D:\stable-diffusion-webui-1.5.1 (1)\newstablediffusionwebui\stable-diffusion-webui\modules\sd_models.py", line 395, in load_model_weights model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) AttributeError: module 'torch' has no attribute 'float8_e4m3fn'

How can I fix this? Sorry if this is the wrong place to ask, I'm somewhat new to github. Open to try any troubleshooting steps you think may help! Thanks for all your work on this! Looking forward to playing around with this.

@BetaDoggo
Copy link

@Manchovies You probably have an older version of torch. The fp8 options were added in torch 2.1. If you haven't reinstalled recently you likely have 2.0 or below. I believe if you delete or rename your venv folder the webui will install 2.1 automatically.

@Manchovies
Copy link

Manchovies commented Jan 26, 2024

@Manchovies You probably have an older version of torch. The fp8 options were added in torch 2.1. If you haven't reinstalled recently you likely have 2.0 or below. I believe if you delete or rename your venv folder the webui will install 2.1 automatically.

That worked! Thank you so much. Now, how can I get xformers installed? I remember trying in the past, and it would always uninstall the version of torch I have installed and install an older version and xformers to go with it. Maybe put "xformers==0.0.22.post7" in the requirements.txt file, or pip install xformers==0.0.22.post7 in the venv folder?

Edit: that seems to have done it. went to /venv/scripts, ran "activate" in cmd, and typed "pip install xformers==0.0.22.post7" and it installed correctly without uninstalling or tampering with the torch install I had set up. Thanks again! Happy to be playing around with FP8 and enjoying the VRAM savings :)

@w-e-w w-e-w mentioned this pull request Feb 17, 2024
@Dampfinchen
Copy link

Dampfinchen commented Mar 3, 2024

Hm, I thought I could run SDXL completely in VRAM with this (6 GB) but it needs --medvram for it to not OOM. Comparing it to Comfy without FP8, Comfy takes around 12 seconds, WebUi with FP8 and medvram around 9 seconds, so its a decent improvement. But idk why it won't fit completely into VRAM.

@KohakuBlueleaf
Copy link
Collaborator Author

p

For 6GB or worse card, you can wait for the next big update for lowvram, which will be as fast as comfy/forge for super lowvram cards. At least the author have tested it on 3060 6G

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.