Very temporary fix to use LoRA with fp8 weight enabled #209

Open
wangziyao318 opened this issue Dec 29, 2023 · 1 comment

Comments

@wangziyao318

This is only for those who want to lower VRAM usage a bit (to be able to use a larger batch size) when using TensorRT with SD WebUI, at the cost of some accuracy.

As far as I can tell, the fp8 optimization (currently available in the SD WebUI dev branch under Settings/Optimizations) only slightly reduces VRAM usage when used with TensorRT: from 10.9 GB to 9.6 GB for a certain SDXL workload, compared with 9.7 GB to 6.8 GB without TensorRT. This is because the TensorRT side still stores data in fp16. VRAM usage would drop further if TensorRT had an option to store data in fp8 as well.
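A rough way to see why the saving shrinks: weight memory scales with bytes per element, and fp8 halves the fp16 footprint only for the copy that is actually stored in fp8. The sketch below uses a hypothetical parameter count (not a measured SDXL figure) and ignores activations and other buffers, which also consume VRAM.

```python
# Bytes per element for the two storage formats discussed above.
BYTES_PER_ELEMENT = {"fp16": 2, "fp8": 1}


def weight_gib(num_params: int, dtype: str) -> float:
    """Approximate memory (GiB) needed to store num_params weights in dtype."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1024**3


# Hypothetical parameter count, for illustration only.
num_params = 2_000_000_000

fp16 = weight_gib(num_params, "fp16")
fp8 = weight_gib(num_params, "fp8")
print(f"fp16: {fp16:.2f} GiB, fp8: {fp8:.2f} GiB, saved: {fp16 - fp8:.2f} GiB")
```

If a second copy of the weights stays in fp16 (as the TensorRT side does here), that copy's share of VRAM is unaffected by the fp8 setting.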

LoRA can't be converted to TensorRT with fp8 enabled because of a dtype cast issue. Here's a very temporary and dirty fix to get it working (on the dev branch).

In `model_helper.py`, line 178:

```python
wt = wt.cpu().detach().half().numpy().astype(np.float16)
```

In `exporter.py`, lines 80 and 82:

```python
wt_hash = hash(wt.cpu().detach().half().numpy().astype(np.float16).data.tobytes())

delta = wt.half() - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device)
```

The idea is to add `.half()` to cast the fp8 tensor to fp16 so that it can be used in calculations with other fp16 values. Also note that the "cache fp16 weight for LoRA" option in Settings/Optimizations doesn't work with this fix, so you need to apply a higher weight to the fp8 LoRA to get the same effect as the LoRA in fp16.

By the way, if you check out the SD WebUI dev branch, which uses cu121, you can switch TensorRT to 9.0.1.post12.dev4 or the newer 9.2.0.post12.dev5 for CUDA 12 (building the wheel for 9.1.0.post12.dev4 failed on my PC, so I don't recommend it). Make sure to update the version number in `install.py`. (TensorRT still works even if you don't change it.)
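To confirm which TensorRT wheel is actually installed after changing the version, the standard library's `importlib.metadata` can report it. A small sketch; the distribution name `tensorrt` is an assumption about how the wheel is registered in your environment:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints None if no distribution with this name is installed.
print("tensorrt:", installed_version("tensorrt"))
```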

@Vinzelles

Vinzelles commented Nov 8, 2024

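Hello, following your method I successfully fixed the LoRA export error. Thank you very much.
However, the converted .lora file is only 1 KB in size. Is that normal? Also, how should the converted LoRA be used?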
