
[LoRA] log a warning when there are missing keys in the LoRA loading. #9622

Merged: 20 commits merged into main from handle-missing-keys-lora on Oct 16, 2024

Conversation

@sayakpaul (Member) commented Oct 9, 2024

What does this PR do?

This came up during an investigation of the XLabs LoRAs (internal link). It turns out that a couple of keys are missing after the set_peft_model_state_dict() call, which affects the results.

This PR adds a warning for missing keys and unexpected keys, similar to https://github.com/huggingface/peft/pull/2084/files#diff-95787fcc5356568929e383449ade2c1bac5da43deccab17318652b62ed171ae7R1181-R1199. Thanks to @BenjaminBossan for the idea. I have added the necessary tests for this, too.
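
For illustration, here is a minimal sketch of the idea (not the exact diffusers implementation; the helper name and the prefix filter are assumptions based on the discussion below):

import logging

from peft import set_peft_model_state_dict

logger = logging.getLogger(__name__)

def warn_on_incompatible_keys(model, peft_state_dict, adapter_name="default_0"):
    # Load the converted LoRA state dict into the PEFT-injected model and inspect
    # the keys that could not be matched.
    incompatible_keys = set_peft_model_state_dict(model, peft_state_dict, adapter_name=adapter_name)

    missing_keys = getattr(incompatible_keys, "missing_keys", None)
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)

    if missing_keys:
        # Every frozen base-model parameter shows up as "missing" when loading with
        # strict=False, so only adapter ("lora") keys are worth warning about.
        lora_missing = [k for k in missing_keys if "lora" in k]
        if lora_missing:
            logger.warning(f"Loading adapter weights from state_dict led to missing keys in the model: {lora_missing}")

    if unexpected_keys:
        logger.warning(f"Loading adapter weights from state_dict led to unexpected keys: {unexpected_keys}")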

It turns out that this PR and the PEFT PRs referenced by @BenjaminBossan below solve a couple of issues: #9622 (comment).

Once I have an approval from Ben, I will request Yiyi's review.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member)
About this line:

missing_keys = getattr(incompatible_keys, "missing_keys", None)

Does that work as expected? I thought that we'd always get missing keys that stem from the base model, which is why in PEFT we keep only the keys with the prefix (in this case it would be "lora_"). Or is this not applicable for some reason?

Regarding the testing, we would probably have to take a valid LoRA adapter (or create one), edit the state dict to delete an entry, and then try to load that adapter. There is a test along these lines in PEFT here:

https://github.com/huggingface/peft/pull/2118/files#diff-df9ecc7077bee932f56e76161ada47693d73acd3ed175a5b9a9158cfe03ec381R1517-R1546
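
Roughly, such a test could look like the sketch below (illustrative only; a tiny dummy model stands in for a real adapter, and the assertion checks the missing keys that the new warning would report):

import torch
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict

# Build a tiny LoRA-fied model and grab its adapter state dict.
base = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
model = get_peft_model(base, LoraConfig(target_modules=["0", "1"], r=4))
state_dict = get_peft_model_state_dict(model)

# Simulate a broken checkpoint by dropping one LoRA entry.
dropped = sorted(state_dict)[0]
state_dict.pop(dropped)

# Reloading should now report the dropped entry among the missing keys,
# which is exactly what the new warning surfaces.
load_result = set_peft_model_state_dict(model, state_dict)
lora_missing = [k for k in load_result.missing_keys if "lora_" in k]
assert lora_missing  # at least the dropped LoRA key is reported as missing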

@sayakpaul (Member, Author)

Does that work as expected? I thought that we'd always get missing keys that stem from the base model, which is why in PEFT we keep only the keys with the prefix (in this case it would be "lora_"). Or is this not applicable for some reason?

If we try to do:

from diffusers import DiffusionPipeline
import torch

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
pipe.load_lora_weights("XLabs-AI/flux-RealismLora")

image = pipe("a mecha robot", generator=torch.manual_seed(1)).images[0]
image.save("xlabs.png")

It correctly gives:

Warning
Loading adapter weights from state_dict led to missing keys in the model:  ['single_transformer_blocks.0.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.0.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.0.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.0.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.0.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.0.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.1.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.1.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.1.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.1.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.1.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.1.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.2.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.2.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.2.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.2.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.2.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.2.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.3.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.3.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.3.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.3.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.3.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.3.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.4.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.4.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.4.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.4.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.4.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.4.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.5.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.5.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.5.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.5.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.5.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.5.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.6.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.6.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.6.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.6.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.6.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.6.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.7.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.7.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.7.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.7.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.7.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.7.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.8.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.8.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.8.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.8.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.8.attn.to_v.lora_A.default_0.weight', 
'single_transformer_blocks.8.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.9.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.9.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.9.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.9.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.9.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.9.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.10.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.10.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.10.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.10.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.10.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.10.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.11.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.11.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.11.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.11.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.11.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.11.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.12.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.12.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.12.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.12.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.12.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.12.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.13.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.13.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.13.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.13.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.13.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.13.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.14.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.14.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.14.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.14.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.14.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.14.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.15.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.15.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.15.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.15.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.15.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.15.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.16.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.16.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.16.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.16.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.16.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.16.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.17.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.17.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.17.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.17.attn.to_k.lora_B.default_0.weight', 
'single_transformer_blocks.17.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.17.attn.to_v.lora_B.default_0.weight', 'single_transformer_blocks.18.attn.to_q.lora_A.default_0.weight', 'single_transformer_blocks.18.attn.to_q.lora_B.default_0.weight', 'single_transformer_blocks.18.attn.to_k.lora_A.default_0.weight', 'single_transformer_blocks.18.attn.to_k.lora_B.default_0.weight', 'single_transformer_blocks.18.attn.to_v.lora_A.default_0.weight', 'single_transformer_blocks.18.attn.to_v.lora_B.default_0.weight']. 

@BenjaminBossan (Member) commented Oct 9, 2024

Indeed the warning looks correct although I'm not sure why it works. When I jump into the debugger (here), I see 1274 missing keys though, when it really should be 114 missing keys. When I copy your fix, my warning contains 1274 keys.

@sayakpaul (Member, Author)

Indeed the warning looks correct although I'm not sure why it works. When I jump into the debugger (here), I see 1274 missing keys though, when it really should be 114 missing keys. When I copy your fix, my warning contains 1274 keys.

Your hunch was correct all along. 58da14b has addressed this. Does this work for you?

@BenjaminBossan (Member)

Yes, that change should do it. In PEFT, we check for "lora_", not "lora", which is probably irrelevant for 99.99% of cases, but just to let you know.
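
For reference, the filter being discussed amounts to the following (illustrative keys; not the exact library code):

# Every frozen base-model parameter appears in missing_keys when loading with
# strict=False, so the warning must filter down to adapter keys.
missing_keys = [
    "transformer_blocks.0.attn.to_q.base_layer.weight",               # base model, expected to be "missing"
    "single_transformer_blocks.0.attn.to_q.lora_A.default_0.weight",  # genuinely missing adapter weight
]
print([k for k in missing_keys if "lora_" in k])  # PEFT checks for "lora_"
print([k for k in missing_keys if "lora" in k])   # diffusers (after 58da14b) checks for "lora"; same result here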

@sayakpaul (Member, Author)

@BenjaminBossan ready for your review

@sayakpaul marked this pull request as ready for review on October 10, 2024, 07:54
@sayakpaul (Member, Author) commented Oct 10, 2024

@BenjaminBossan some interesting findings here.

When I ran the following with this PR:

from diffusers import DiffusionPipeline
import torch

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
pipe.load_lora_weights("XLabs-AI/flux-RealismLora")

image = pipe("a mecha robot", generator=torch.manual_seed(1)).images[0]
image.save("xlabs.png")

I get:

Loading adapter weights from state_dict led to missing keys in the model: 

single_transformer_blocks.0.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.0.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.0.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.0.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.0.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.0.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.1.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.1.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.1.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.1.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.1.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.1.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.2.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.2.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.2.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.2.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.2.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.2.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.3.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.3.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.3.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.3.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.3.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.3.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.4.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.4.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.4.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.4.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.4.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.4.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.5.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.5.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.5.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.5.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.5.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.5.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.6.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.6.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.6.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.6.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.6.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.6.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.7.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.7.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.7.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.7.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.7.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.7.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.8.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.8.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.8.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.8.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.8.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.8.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.9.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.9.attn.to_q.lora_B.default_0.weight, 
single_transformer_blocks.9.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.9.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.9.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.9.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.10.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.10.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.10.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.10.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.10.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.10.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.11.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.11.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.11.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.11.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.11.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.11.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.12.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.12.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.12.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.12.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.12.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.12.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.13.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.13.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.13.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.13.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.13.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.13.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.14.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.14.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.14.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.14.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.14.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.14.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.15.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.15.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.15.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.15.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.15.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.15.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.16.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.16.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.16.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.16.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.16.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.16.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.17.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.17.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.17.attn.to_k.lora_A.default_0.weight, single_transformer_blocks.17.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.17.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.17.attn.to_v.lora_B.default_0.weight, single_transformer_blocks.18.attn.to_q.lora_A.default_0.weight, single_transformer_blocks.18.attn.to_q.lora_B.default_0.weight, single_transformer_blocks.18.attn.to_k.lora_A.default_0.weight, 
single_transformer_blocks.18.attn.to_k.lora_B.default_0.weight, single_transformer_blocks.18.attn.to_v.lora_A.default_0.weight, single_transformer_blocks.18.attn.to_v.lora_B.default_0.weight

But,

target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})

doesn't actually contain ANY single_transformer_blocks whatsoever. So, isn't the warning on missing_keys weird or am I missing something?
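
(For context, a toy illustration of what that expression produces; the keys below are made up:)

# Illustrative only: deriving target_modules from a converted LoRA state dict.
peft_state_dict = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": ...,
    "transformer_blocks.0.attn.to_q.lora_B.weight": ...,
    "transformer_blocks.0.attn.add_q_proj.lora_A.weight": ...,
}
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
# -> ['transformer_blocks.0.attn.to_q', 'transformer_blocks.0.attn.add_q_proj'] (order may vary)
# Nothing here mentions single_transformer_blocks, hence the surprise about the warning.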

@apolinario internally reported https://huggingface.co/dataautogpt3/FLUX-SyntheticAnime to be causing something similar to this. But when I did:

from diffusers import DiffusionPipeline
import torch

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
pipe.load_lora_weights("dataautogpt3/FLUX-SyntheticAnime")

image = pipe("1980s anime screengrab, VHS quality, a woman with her face glitching and disorted, a halo above her head", generator=torch.manual_seed(1)).images[0]
image.save("syn_anime.png")

I didn't get any warnings on unexpected keys or missing keys either.

@BenjaminBossan (Member) commented Oct 10, 2024

doesn't actually contain ANY single_transformer_blocks whatsoever. So, isn't the warning on missing_keys weird or am I missing something?

Well, the keys are missing, right? Therefore, you can't find them in the state_dict. But when we look at the model, we see:

FluxTransformer2DModel(
  [...]
  (single_transformer_blocks): ModuleList(
    (0-18): 19 x FluxSingleTransformerBlock(
      (norm): AdaLayerNormZeroSingle(
        (silu): SiLU()
        (linear): Linear(in_features=3072, out_features=9216, bias=True)
        (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      )
      (proj_mlp): Linear(in_features=3072, out_features=12288, bias=True)
      (act_mlp): GELU(approximate='tanh')
      (proj_out): Linear(in_features=15360, out_features=3072, bias=True)
      (attn): Attention(
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (to_q): lora.Linear(
          (base_layer): Linear(in_features=3072, out_features=3072, bias=True)
          (lora_dropout): ModuleDict(
            (default_0): Identity()
          )
          (lora_A): ModuleDict(
            (default_0): Linear(in_features=3072, out_features=16, bias=False)
          )
          (lora_B): ModuleDict(
            (default_0): Linear(in_features=16, out_features=3072, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
        (to_k): lora.Linear(
          (base_layer): Linear(in_features=3072, out_features=3072, bias=True)
          (lora_dropout): ModuleDict(
            (default_0): Identity()
          )
          (lora_A): ModuleDict(
            (default_0): Linear(in_features=3072, out_features=16, bias=False)
          )
          (lora_B): ModuleDict(
            (default_0): Linear(in_features=16, out_features=3072, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
        (to_v): lora.Linear(
          (base_layer): Linear(in_features=3072, out_features=3072, bias=True)
          (lora_dropout): ModuleDict(
            (default_0): Identity()
          )
          (lora_A): ModuleDict(
            (default_0): Linear(in_features=3072, out_features=16, bias=False)
          )
          (lora_B): ModuleDict(
            (default_0): Linear(in_features=16, out_features=3072, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
      )
    )
  [...]

So there is a part of the model called "single_transformer_blocks" that has LoRA layers. But should it have LoRA layers?

I checked the target_modules and found:

['add_q_proj', '4.attn.to_v', 'to_out.0', '1.attn.to_k', '1.attn.to_q', '16.attn.to_k', '6.attn.to_q', '15.attn.to_q', '3.attn.to_q', '10.attn.to_v', '3.attn.to_k', '17.attn.to_q', '9.attn.to_q', '7.attn.to_v', '12.attn.to_k', '10.attn.to_q', '12.attn.to_q', '16.attn.to_q', '18.attn.to_k', '13.attn.to_q', '14.attn.to_k', '4.attn.to_k', '7.attn.to_k', '11.attn.to_q', '4.attn.to_q', '8.attn.to_q', '2.attn.to_q', '0.attn.to_v', '2.attn.to_k', '7.attn.to_q', '11.attn.to_k', 'to_add_out', '12.attn.to_v', '9.attn.to_k', '17.attn.to_k', '3.attn.to_v', '14.attn.to_v', '18.attn.to_q', '13.attn.to_k', '16.attn.to_v', '1.attn.to_v', '5.attn.to_k', '13.attn.to_v', '8.attn.to_v', 'add_v_proj', '9.attn.to_v', '17.attn.to_v', '0.attn.to_q', '5.attn.to_q', '5.attn.to_v', '15.attn.to_v', '18.attn.to_v', '0.attn.to_k', '8.attn.to_k', '15.attn.to_k', 'add_k_proj', '6.attn.to_k', '14.attn.to_q', '10.attn.to_k', '11.attn.to_v', '6.attn.to_v', '2.attn.to_v']

This looks very suspicious to me. Keys like "10.attn.to_q" do match modules inside "single_transformer_blocks", which explains why it has LoRA layers, but is the key right in the first place?

I checked if it could be due to huggingface/peft#2045 and indeed, if I remove this optimization, the error goes away because "single_transformer_blocks" is no longer targeted. So what this means is that the optimization leads to the wrong target_modules being defined. As a consequence, we establish LoRA layers where there shouldn't be any, which results in these missing keys.

Targeting these missing LoRA layers before we had low_cpu_mem_usage=True did not result in errors because we would initialize those LoRA weights to be identity transforms, so they didn't do anything (except slowing down inference a bit). But now, with low_cpu_mem_usage=True, this is no longer true and we see the error in the generated image.
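
(A hedged illustration of why the default initialization hides this: with lora_B initialized to zeros, the LoRA delta is exactly zero, so a spuriously injected layer is a no-op:)

import torch

# Default LoRA init: lora_B starts at zero, so the added update B @ A is zero.
lora_A = torch.randn(16, 3072)   # down-projection, randomly initialized
lora_B = torch.zeros(3072, 16)   # up-projection, zero-initialized
delta = lora_B @ lora_A
print(torch.count_nonzero(delta))  # tensor(0) -> the spurious layer changes nothing
# With low_cpu_mem_usage=True, weights for layers that have no matching checkpoint
# entry are no longer guaranteed to be zero-initialized, so the output changes.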

I'll investigate why we get those incorrect target modules but more and more I get the impression that this PEFT optimization was a bad idea.

Edit: Bugfix here: huggingface/peft#2144. Still wondering if this whole feature should be rolled back though.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 10, 2024
Solves the following bug:

huggingface/diffusers#9622 (comment)

The cause for the bug is as follows: When we have, say, a module called
"bar.0.query" that we want to target and another module called
"foo_bar.0.query" that we don't want to target, there was potential for
an error. This is not caused by _find_minimal_target_modules directly,
but rather the bug was inside of BaseTuner.inject_adapter and how the
names_no_target were chosen. Those used to be chosen based on suffix. In
our example, however, "bar.0.query" is a suffix of "foo_bar.0.query",
therefore "foo_bar.0.query" was *not* added to names_no_target when it
should have. As a consequence, during the optimization, it looks like
"query" is safe to use as target_modules because we don't see that it
wrongly matches "foo_bar.0.query".
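
To make the failure mode concrete, here is a toy illustration of the suffix issue (module names invented; the boundary-check fix shown is the general idea, not necessarily the exact PEFT code):

targeted = "bar.0.query"          # module we want LoRA on
not_targeted = "foo_bar.0.query"  # module we do NOT want LoRA on

# Old logic: names_no_target excluded any module that a targeted name was a suffix of.
print(not_targeted.endswith(targeted))  # True -> wrongly left out of names_no_target

# With "foo_bar.0.query" missing from names_no_target, the optimization concludes
# that the pattern "query" alone is safe, and LoRA layers get injected where they
# should not exist (here: single_transformer_blocks.*), producing the missing keys.

# The fix is to match only on full module-path boundaries:
def matches_on_boundary(name: str, target: str) -> bool:
    return name == target or name.endswith("." + target)

print(matches_on_boundary(not_targeted, targeted))  # False -> correctly kept in names_no_target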
BenjaminBossan added a commit to BenjaminBossan/transformers that referenced this pull request Oct 10, 2024
When loading a LoRA adapter, so far, there was only a warning when there
were unexpected keys in the checkpoint. Now, there is also a warning
when there are missing keys.

This change is consistent with
huggingface/peft#2118 in PEFT and the planned PR
huggingface/diffusers#9622 in diffusers.

Apart from this change, the error message for unexpected keys was
slightly altered for consistency (it should be more readable now). Also,
besides adding a test for the missing keys warning, a test for
unexpected keys warning was also added, as it was missing so far.
@sayakpaul (Member, Author) commented Oct 10, 2024

Still wondering if this whole feature should be rolled back though.

You mean the _find_minimal_target_modules() feature or low_cpu_mem_usage?

Targeting these missing LoRA layers before we had low_cpu_mem_usage=True did not result in errors because we would initialize those LoRA weights to be identity transforms, so they didn't do anything (except slowing down inference a bit). But now, with low_cpu_mem_usage=True, this is no longer true and we see the error in the generated image.

I guess you meant

Targeting these missing LoRA layers before we had low_cpu_mem_usage=False

BenjaminBossan added a commit to huggingface/peft that referenced this pull request Oct 10, 2024
(Same commit message as above.)
@BenjaminBossan (Member)

You mean the _find_minimal_target_modules() feature or low_cpu_mem_usage?

The _find_minimal_target_modules feature. It's already the second bug after this one :-/

I guess you meant

What I meant: When we have low_cpu_mem_usage=False, the error would not surface for the reasons I mentioned. Same for the time before the low_cpu_mem_usage option was introduced.

@sayakpaul (Member, Author)

Okay thanks for explaining, @BenjaminBossan!

I guess this PR should be in a good state to review now.

Edit: Bugfix here: huggingface/peft#2144. Still wondering if this whole feature should be rolled back though.

  1. Do you think it makes sense to do a patch release for this?
  2. I think it's okay to keep it for now, honestly as it does help improve things quite a bit as long as we're able to catch the edge cases nicely.

@sayakpaul (Member, Author)

@apolinario can you run some tests, ensuring you're on peft and transformers from main and diffusers from this PR branch?

@BenjaminBossan (Member)

  1. Do you think it makes sense to do a patch release for this?

Depending on the urgency, I can look into making a patch release.

2. I think it's okay to keep it for now, honestly as it does help improve things quite a bit as long as we're able to catch the edge cases nicely.

At least, it should be easier to catch possible errors with the new warning. I wonder if we can check a bunch of external LoRAs to see if there are more warnings?

@sayakpaul (Member, Author)

At least, it should be easier to catch possible errors with the new warning. I wonder if we can check a bunch of external LoRAs to see if there are more warnings?

Will trigger the integration tests from https://github.com/huggingface/diffusers/blob/main/tests/lora/ tomorrow and see.

@apolinario (Collaborator)

@apolinario can you run some tests, ensuring you're on peft and transformers from main and diffusers from this PR branch?

With this combination of versions I can confirm that:

  1. The LoRAs produce expected results (not jumbled).
  2. set_adapter works when combining these LoRAs with the ones that were previously breaking.
  3. The warning messages are not displayed, but that is because loading now works, hence no key issues.

I think huggingface/peft#2144 fixed the broken LoRA issues I was experiencing, which IMO could merit a patch release, @BenjaminBossan, as it will fix a few LoRAs for quite a few people for whom things are currently failing silently; great work, folks!

BenjaminBossan added a commit to huggingface/peft that referenced this pull request Oct 11, 2024
(Same commit message as above.)
@BenjaminBossan (Member)

which imo could merit a patch

Done: https://github.com/huggingface/peft/releases/tag/v0.13.2.

@BenjaminBossan (Member) left a review comment:

Thanks, LGTM.

saqlain2204 and others added 3 commits October 12, 2024 00:20
* Added diff diff support for kolors img2img

* Fized relative imports

* Fized relative imports

* Added diff diff support for Kolors

* Fized import issues

* Added map

* Fized import issues

* Fixed naming issues

* Added diffdiff support for Kolors img2img pipeline

* Removed example docstrings

* Added map input

* Updated latents

Co-authored-by: Álvaro Somoza <[email protected]>

* Updated `original_with_noise`

Co-authored-by: Álvaro Somoza <[email protected]>

* Improved code quality

---------

Co-authored-by: Álvaro Somoza <[email protected]>
@sayakpaul (Member, Author)

@yiyixuxu could you give this a look too?

I have run the integration tests for SD and SDXL LoRAs and they are already passing. Will run the Flux LoRA integration tests separately on a 24GB card.

@@ -208,9 +209,11 @@ def test_flux_the_last_ben(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+        expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
@sayakpaul (Member, Author) commented on the diff:

Because of the recent fixes in peft. The tests would pass with peft==0.12.0.

@yiyixuxu (Collaborator) left a review comment:

thanks!

@sayakpaul merged commit cef4f65 into main on Oct 16, 2024. 16 checks passed.
@sayakpaul deleted the handle-missing-keys-lora branch on October 16, 2024 at 02:16.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 22, 2024
(Same commit message as above.)
ArthurZucker pushed a commit to huggingface/transformers that referenced this pull request Oct 24, 2024
(Same commit message as the transformers commit above.)