Hello, thank you so much for sharing this repository! I want to ask a question about the implementation.
According to the code, the ModelCheckpoint callback is used to save the model weights (see lightning-mlflow-hf/lightning_mlflow/train.py, lines 57 to 76 at cf1b6b9). However, as I understand it, this saves the whole model's weights, whereas usually only the adapter weights are saved for memory efficiency. Could you please share your thoughts on this?
Yes, that's a good observation. The ModelCheckpoint callback from PyTorch Lightning does indeed save all weights, including the frozen ones, when it would be more efficient to checkpoint only the fine-tuned weights.
Also, I think this efficiency issue exists not only in LoRA fine-tuning but also in linear probing, or for that matter any kind of partial fine-tuning. The checkpoint callback currently does not check for or handle what type of fine-tuning is happening, or which parameters are frozen. Maybe it should, though (this could be a ticket for the PyTorch Lightning project). Do other checkpointing tools (for plain PyTorch, for the Hugging Face Trainer, or for TensorFlow) have such functionality?
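In plain PyTorch at least, where checkpointing is done by hand anyway, the usual workaround is to filter the state dict by requires_grad before saving. A minimal sketch (the model and file names here are made up for illustration, not taken from this repository):

```python
import torch
import torch.nn as nn

# Toy linear-probing setup: frozen backbone, trainable head.
model = nn.Sequential(nn.Linear(768, 768), nn.Linear(768, 2))
for param in model[0].parameters():
    param.requires_grad = False

# Save only the parameters that are actually being trained.
trainable = {
    name: param
    for name, param in model.named_parameters()
    if param.requires_grad
}
torch.save(trainable, "head_only.pt")

# Restore later with strict=False, since the checkpoint is partial.
model.load_state_dict(torch.load("head_only.pt"), strict=False)
```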
FYI, I found a related issue in the PyTorch Lightning repository, where it is suggested that one may manually filter the weights to be saved by overriding the state_dict method of the LightningModule: Lightning-AI/pytorch-lightning#19149 (comment)
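For context, that suggestion amounts to something like the following minimal sketch (the module and layer names are illustrative, not from this repository):

```python
import lightning.pytorch as pl
import torch.nn as nn


class AdapterOnlyCheckpointing(pl.LightningModule):
    """Illustrative module that checkpoints only its trainable weights."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(768, 768)  # stands in for the frozen base model
        self.adapter = nn.Linear(768, 768)   # stands in for the fine-tuned weights
        self.backbone.requires_grad_(False)

    def state_dict(self, *args, **kwargs):
        # Filter the full state dict down to trainable parameters, so
        # ModelCheckpoint writes only the fine-tuned weights to disk.
        full = super().state_dict(*args, **kwargs)
        trainable = {n for n, p in self.named_parameters() if p.requires_grad}
        return {k: v for k, v in full.items() if k in trainable}
```

Note that such a partial checkpoint then has to be loaded with load_state_dict(..., strict=False), since the frozen weights are missing from it.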
Also, I believe the official way to save adapter weights with the peft library is to use the PeftModel.save_pretrained method.
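For reference, a minimal sketch of that flow (the base model choice and LoRA hyperparameters are arbitrary examples):

```python
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)
config = LoraConfig(
    task_type="SEQ_CLS", r=8, lora_alpha=16, target_modules=["q_lin", "v_lin"]
)
peft_model = get_peft_model(base, config)

# ... fine-tune peft_model here ...

# Writes only adapter_config.json and the adapter weights
# (typically a few MB), not the full base model.
peft_model.save_pretrained("my-lora-adapter")

# To restore, re-create the base model and attach the saved adapter.
fresh_base = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)
restored = PeftModel.from_pretrained(fresh_base, "my-lora-adapter")
```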