
Inconsistency between get_nb_trainable_parameters and num_parameters(only_trainable=True) for prompt tuning #1526

Closed
kmehant opened this issue Mar 4, 2024 · 9 comments

@kmehant (Contributor) commented Mar 4, 2024

System Info

peft==0.8.2
transformers==4.37.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from peft import PromptTuningConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM

model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("bigcode/starcoderbase-1b"))

pt_config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
    prompt_tuning_init="TEXT",
    prompt_tuning_init_text="some prompt tuning init text",
    tokenizer_name_or_path="path to tokenizer",
)

model_pt = get_peft_model(model, pt_config)

model_pt.get_nb_trainable_parameters() # output 1 (first element in the tuple)
model_pt.num_parameters(only_trainable=True) # output 2

# output 1 does not match output 2
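
For reference, the mismatch can be cross-checked by counting parameters directly from the requires_grad flags (a minimal sketch, assuming the model_pt object from above):

# Trainable / total parameters of the PeftModel itself; the first value should
# match the first element returned by get_nb_trainable_parameters()
trainable = sum(p.numel() for p in model_pt.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_pt.parameters())
print(trainable, total)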

Expected behavior

[UPDATED]

We should have trainable parameters > 0; this would be consistent with LoRA. For CodeLlama 7B I get

model_lora.get_nb_trainable_parameters() # (8388608, 6746935296)

model_lora.num_parameters(only_trainable=True) # 8388608

However, for prompt tuning,

model_pt.get_nb_trainable_parameters() # (81920, 6738628608)

model_pt.num_parameters(only_trainable=True) # 0

I agree with the outputs from get_nb_trainable_parameters(). However, I am trying to understand this inconsistent behaviour of num_parameters(only_trainable=True) between LoRA and prompt tuning.

@BenjaminBossan (Member)

When I tried, I got 40960 and 0 parameters, respectively. But that isn't too surprising. The num_parameters method is not defined in the PeftModel (i.e. the thing you get back after calling get_peft_model). Therefore, the method call is delegated to the base model itself, i.e. the starcoderbase. As the base model is frozen, we get back 0. As for the PeftModel, it adds new trainable parameters, or else fine-tuning wouldn't work, so that's why we get a number > 0 back.
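
A quick way to see that delegation in action (a sketch, assuming the reproduction above; get_base_model() returns the wrapped transformers model):

# num_parameters is not defined on PeftModel, so the call falls through to the
# wrapped transformers model, whose weights are all frozen after get_peft_model.
base = model_pt.get_base_model()
print(any(p.requires_grad for p in base.parameters()))  # False: the base model itself has no trainable parameters
print(base.num_parameters(only_trainable=True))         # 0, same as model_pt.num_parameters(only_trainable=True)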

@kmehant (Contributor, Author) commented Mar 4, 2024

#1526 (comment)

@BenjaminBossan

Right, we should have trainable parameters > 0; this would be consistent with LoRA. For CodeLlama 7B I get

model_lora.get_nb_trainable_parameters() # (8388608, 6746935296)

model_lora.num_parameters(only_trainable=True) # 8388608

However, for prompt tuning,

model_pt.get_nb_trainable_parameters() # (81920, 6738628608)

model_pt.num_parameters(only_trainable=True) # 0

I agree with the outputs from get_nb_trainable_parameters(). However, I am trying to understand this inconsistent behaviour of num_parameters(only_trainable=True) between LoRA and prompt tuning. (I have just updated the issue with this info as well.)

When I tried, I got 40960 and 0 parameters, respectively. But that isn't too surprising. The num_parameters method is not defined in the PeftModel (i.e. the thing you get back after calling get_peft_model). Therefore, the method call is delegated to the base model itself, i.e. the starcoderbase. As the base model is frozen, we get back 0.

I see, does that mean LoRA is modifying the base model in place with the adapters?

@kmehant (Contributor, Author) commented Mar 4, 2024

When I tried, I got 40960 and 0 parameters, respectively. But that isn't too surprising. The num_parameters method is not defined in the PeftModel (i.e. the thing you get back after calling get_peft_model). Therefore, the method call is delegated to the base model itself, i.e. the starcoderbase. As the base model is frozen, we get back 0.

I see. For prompt tuning, the prompt_encoder is added as a separate module and, by design, does not modify the base model, so num_parameters(only_trainable=True) reports 0.
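
Listing the trainable parameters by name makes this visible (a small sketch; the exact parameter name depends on the PEFT version):

# Only the prompt encoder embedding is trainable for prompt tuning, e.g. a name
# like prompt_encoder.default.embedding.weight with num_virtual_tokens * hidden_size
# elements, and it lives outside the frozen base model.
for name, param in model_pt.named_parameters():
    if param.requires_grad:
        print(name, param.numel())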

@kmehant (Contributor, Author) commented Mar 4, 2024

@BenjaminBossan

Do you think overriding num_parameters() in the PeftModel would make it less confusing, or do you think this behaviour is easy enough to understand as is? Although num_parameters() comes from the parent class, it is exposed on the PeftModel, so I am worried it might be misunderstood until one digs into the code.

@BenjaminBossan (Member)

I see, does that mean LoRA is modifying the base model in place with the adapters?

Yes, exactly, this is the reason.
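
To make that concrete (a hypothetical check, assuming a LoRA-wrapped model_lora as in the numbers above):

# LoRA injects its adapter weights into the base model's own modules, so even the
# delegated num_parameters call sees trainable parameters there.
base = model_lora.get_base_model()
print(sum(p.numel() for p in base.parameters() if p.requires_grad))  # e.g. 8388608 for CodeLlama 7B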

I agree that it could be confusing. In the PEFT docs, we only advertise print_trainable_parameters, which calls get_nb_trainable_parameters under the hood. I guess we could override num_parameters() to avoid confusion, but we could also leave it there, as it serves a slightly different purpose. Maybe we could just document the difference in the docstring of print_trainable_parameters, WDYT?
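
For completeness, the advertised helper is used like this (the exact output format may differ between versions):

# Prints a summary based on get_nb_trainable_parameters(), e.g. something like
# trainable params: 81,920 || all params: 6,738,628,608 || trainable%: 0.0012
model_pt.print_trainable_parameters()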

@kmehant (Contributor, Author) commented Mar 4, 2024

I guess we could override num_parameters() to avoid confusion, but we could also leave it there, as it serves a slightly different purpose. Maybe we could just document the difference in the docstring of print_trainable_parameters, WDYT?

Your call @BenjaminBossan. I am happy to raise a PR for either of them.

@BenjaminBossan (Member)

I think documenting is the better solution. Overriding num_parameters() could theoretically even break some existing code, even if it's unlikely. It would be great if you opened a PR, thanks.

@kmehant (Contributor, Author) commented Mar 4, 2024

@BenjaminBossan raised a PR: #1531. Thanks.

@kmehant (Contributor, Author) commented Mar 6, 2024

Fixed in #1531.

@kmehant kmehant closed this as completed Mar 6, 2024