Will all of the model weights be saved? #1

Open
function2-llx opened this issue Feb 21, 2024 · 2 comments

Comments

@function2-llx

Hello, thank you so much for sharing this repository! I want to ask a question about the implementation.

According to the code, the ModelCheckpoint callback is used to save the model weights. However, as far as I understand, it will save the full model weights, whereas usually only the adapter weights are saved to keep checkpoints small. Could you please share your thoughts on this?

checkpoint_callback = ModelCheckpoint(
    filename="{epoch}-{Val_F1_Score:.2f}",
    monitor="Val_F1_Score",
    mode="max",
    verbose=True,
    save_top_k=1,
)
# Run the training loop.
trainer = Trainer(
    callbacks=[
        EarlyStopping(
            monitor="Val_F1_Score",
            min_delta=config.min_delta,
            patience=config.patience,
            verbose=True,
            mode="max",
        ),
        checkpoint_callback,
    ],
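
For reference, a checkpoint produced by this callback stores the full state_dict of the LightningModule. A quick way to confirm (the checkpoint filename below is just a hypothetical instance of the configured pattern):

import torch

# Load a checkpoint written by the ModelCheckpoint callback above.
ckpt = torch.load("epoch=3-Val_F1_Score=0.91.ckpt", map_location="cpu")

# "state_dict" contains every parameter of the LightningModule, the frozen
# base-model weights included, which is why the file is large.
n_params = sum(t.numel() for t in ckpt["state_dict"].values())
print(f"{len(ckpt['state_dict'])} tensors, {n_params:,} parameters")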

@zjohn77
Owner

zjohn77 commented Feb 22, 2024

Yes, that's a good observation. This checkpoint callback from PyTorch Lightning indeed saves all of the weights, including the frozen ones, when it would be more efficient to checkpoint only the fine-tuned weights.

Also, I think this efficiency issue exists not only in LoRA fine-tuning but also in linear probing, or for that matter any kind of partial fine-tuning. The checkpoint callback currently does not check for or handle what type of fine-tuning is happening. Maybe it should, though (this could be a ticket for the PyTorch Lightning project). Do other checkpointing tools (for base PyTorch, for the Hugging Face Trainer, or for TensorFlow) have such functionality?
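
For what it's worth, at the plain PyTorch level the generic workaround I am aware of is to filter the state dict down to the parameters that are actually trainable before saving. A minimal sketch (the helper name and file path are made up):

import torch

def trainable_state_dict(model: torch.nn.Module) -> dict:
    # Keep only the parameters that are being updated (requires_grad=True),
    # i.e. the LoRA / probe weights rather than the frozen backbone.
    return {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

# torch.save(trainable_state_dict(model), "adapter_only.pt")
# Restore later with model.load_state_dict(torch.load("adapter_only.pt"), strict=False)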

@function2-llx
Author

FYI, I found a related issue in the PyTorch Lightning repository, where it is suggested that one can manually filter the weights to be saved by overriding the state_dict method of the LightningModule:
Lightning-AI/pytorch-lightning#19149 (comment)
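
A minimal sketch of that approach, assuming the adapter parameters can be identified by "lora_" in their names and that the wrapped PEFT model is stored in self.model (both are assumptions, not code from this repository):

import lightning.pytorch as pl
import torch

class LoraLitModule(pl.LightningModule):  # hypothetical module name
    def __init__(self, peft_model: torch.nn.Module):
        super().__init__()
        self.model = peft_model

    def state_dict(self, *args, **kwargs):
        # Return only the adapter entries, so ModelCheckpoint writes a small file.
        full = super().state_dict(*args, **kwargs)
        return {k: v for k, v in full.items() if "lora_" in k}

Loading such a checkpoint back into a full model would then need load_state_dict(..., strict=False), since the frozen weights are missing from the file.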

Also, I believe the official way to save adapter weights with the peft library is to use the PeftModel.save_pretrained method.
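
For completeness, a rough example of that route (the base checkpoint name and output directory are placeholders):

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("roberta-base")
peft_model = get_peft_model(base, LoraConfig(task_type="SEQ_CLS"))

# save_pretrained writes only the adapter weights plus adapter_config.json,
# not the frozen base model.
peft_model.save_pretrained("lora-adapter/")

# Later, reattach the saved adapter to a freshly loaded copy of the base model.
restored = PeftModel.from_pretrained(
    AutoModelForSequenceClassification.from_pretrained("roberta-base"),
    "lora-adapter/",
)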
