
NoneType object has no attribute 'replace' when text_to_image/train_text_to_image_lora saves via save_attn_procs #2616

Closed
better629 opened this issue Mar 9, 2023 · 9 comments
Assignees
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

@better629

better629 commented Mar 9, 2023

Describe the bug

inside examples/text_to_image/train_text_to_image_lora.py

        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

The call to save_attn_procs leaves weights_name as None, which causes the failure in
https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py#L269-L274
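The failure can be reproduced in isolation; a minimal sketch of what happens when weights_name is left at its default of None and the cleanup code calls .replace on it:

```python
# Minimal reproduction of the reported failure mode: weights_name
# defaults to None, and the folder-cleanup step in loaders.py calls
# .replace on it before the None-check runs.
weights_name = None

try:
    weights_name.replace(".bin", "")  # mirrors the cleanup step
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'replace'
```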

Reproduction

Logs

No response

System Info

Ubuntu 20.04
Nvidia GTX 3090
CUDA Version: 11.7
Torch: 1.13.1
Diffusers: 0.15.0.dev0
deepspeed: 0.8.1
xformers: 0.0.17.dev466
accelerate: 0.16.0

@better629 better629 added the bug Something isn't working label Mar 9, 2023
@patrickvonplaten
Contributor

Hey @better629,

Could you please provide a reproducible code snippet?

@patrickvonplaten
Contributor

Also cc @sayakpaul for the text-to-image LoRA training script

@better629
Author

Just run examples/text_to_image/train_text_to_image_lora.py; it appears to be a logic problem in unet.save_attn_procs(args.output_dir).

@wfng92
Contributor

wfng92 commented Mar 9, 2023

This is the same issue as #2548. PR #2448 introduced a regression when saving with unet.save_attn_procs, since the new version defaults weights_name to None instead of "pytorch_lora_weights.bin".

Since the check for None

if weights_name is None:
    if safe_serialization:
        weights_name = LORA_WEIGHT_NAME_SAFE
    else:
        weights_name = LORA_WEIGHT_NAME

comes after the code that cleans the folder from a previous save

weights_no_suffix = weights_name.replace(".bin", "")

so the script throws the following exception:

AttributeError: 'NoneType' object has no attribute 'replace'

A temporary workaround is to save and load as follows:

# save
unet.save_attn_procs(args.output_dir, weights_name="pytorch_lora_weights.bin")
# load
unet.load_attn_procs(args.output_dir, weight_name="pytorch_lora_weights.bin")

For reasons unknown, saving uses weights_name while loading uses weight_name.
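The underlying fix is just an ordering change: resolve the None default before any string operations run on weights_name. A rough sketch of that ordering (not the actual diffusers patch; resolve_weights_name is a hypothetical helper, and the constants follow the snippet above):

```python
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

def resolve_weights_name(weights_name=None, safe_serialization=False):
    # Resolve the None default FIRST, so the cleanup step below
    # never sees None and .replace cannot raise AttributeError.
    if weights_name is None:
        weights_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME
    # Folder-cleanup step from the save path, now safe to run.
    weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
    return weights_name, weights_no_suffix
```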

@sayakpaul
Member

sayakpaul commented Mar 13, 2023

Thanks for investigating @wfng92! Would you like to contribute a PR? Pinging @wangjuns as well because of #2548 (comment).

@wfng92
Contributor

wfng92 commented Mar 13, 2023

Thanks for investigating @wfng92! Would you like to contribute a PR? Pinging @wangjuns as well because of #2548 (comment).

Hi, I believe it is better for the diffusers and safetensors maintainers to sort this issue out, as this regression affects every training script that saves with the default call unet.save_attn_procs(args.output_dir). The proposed workaround is not scalable; the default call should work out of the box without requiring a filename.

There needs to be a better way to integrate safetensors without affecting users who:

  1. did not install safetensors in the environment
  2. prefer the default setup

A few possible scenarios:

no safetensors

  1. default
    unet.save_attn_procs(args.output_dir) saves with pytorch_lora_weights.bin as filename
  2. Attempt to save as safetensors but package is not installed
    unet.save_attn_procs(args.output_dir, weights_name="pytorch_lora_weights.safetensors") saves with ???

Personally, I would prefer to at least save a file (with a way to differentiate the format) rather than get an error with nothing saved. This avoids having to re-train from the last checkpoint after an exception.
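One way to implement "save something rather than error out" is a small availability check before choosing the filename. A hypothetical sketch of that behavior (pick_save_format is not a diffusers API, just an illustration of the fallback):

```python
import importlib.util
import warnings

def pick_save_format(requested_safetensors: bool) -> str:
    # Fall back to .bin with a warning instead of raising, so a long
    # training run still produces a usable checkpoint.
    have_safetensors = importlib.util.find_spec("safetensors") is not None
    if requested_safetensors and not have_safetensors:
        warnings.warn("safetensors not installed; falling back to .bin")
        return "pytorch_lora_weights.bin"
    if requested_safetensors:
        return "pytorch_lora_weights.safetensors"
    return "pytorch_lora_weights.bin"
```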

installed safetensors

  1. same as default
    unet.save_attn_procs(args.output_dir) saves with pytorch_lora_weights.bin as filename

  2. Implicit save if safetensors is installed. This might confuse users who prefer to save as .bin but have safetensors installed
    unet.save_attn_procs(args.output_dir) saves with pytorch_lora_weights.safetensors as filename

  3. Explicit save provided that safetensors is installed
    unet.save_attn_procs(args.output_dir, weights_name="pytorch_lora_weights.safetensors") saves with pytorch_lora_weights.safetensors as filename

Also, it might be better to standardize the naming for weights_name/weight_name when saving and loading.

# save
unet.save_attn_procs(args.output_dir, weights_name="pytorch_lora_weights.bin")
# load
unet.load_attn_procs(args.output_dir, weight_name="pytorch_lora_weights.bin")

@patrickvonplaten
Contributor

Thanks a mille for the nice summary here @better629! I think #2655 should fix everything - could you give it a try? :-)

@better629
Author

@patrickvonplaten Yes, thank you!

@github-actions

github-actions bot commented Apr 8, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 8, 2023