-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
auto save model in lightning #613
Conversation
Overall, I felt opening a PR with the desired behavior would be more effective than explaining and discussing in an issue. I hope this will help resolve some of the rough edges around saving models and that we can work through the other framework callbacks to implement similar functionality that works with the existing framework conventions and resembles mlflow, wandb, etc. |
One thing to note: lightning will not overwrite existing files or clean up between runs. Instead, it will append a version number, so if you run the same code repeatedly, you will end up with a directory that tracks your entire history of model checkpoints instead of only the latest run:
If you are running a pipeline, this is probably fine since you can control if you want to delete that directory each time. We might also want to consider dropping the existing checkpoints directory in the dvclive callback if |
This comment was marked as resolved.
This comment was marked as resolved.
This comment was marked as resolved.
This comment was marked as resolved.
I think this gets us to a good place with logging models in lightning. In fact, comparing to other trackers, it feels a bit easier to manage the models this way in dvc. On a different machine, you can pull the lightning checkpoints dir and keep using lightning methods to load those checkpoints. With other trackers, once you are on a different machine, the only way to load models is using the experiment tracker's api. |
# Save model checkpoints. | ||
if self._log_model is True: | ||
self.experiment.log_artifact(checkpoint_callback.dirpath) | ||
# Log best model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT of creating a copy in "dvclive" folder (or in the checkpoints folder itself), at least for the best?
It seems that we would be changing the path of the registered model between experiments in the current behavior
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, which I guess is also what other trackers do AFAICT? Do you think the path matters? Maybe it makes it easier to dvc get
later, although we could make that work by the artifact name. No strong opinion from me.
I am ok with moving forward in this direction and prioritizing similar behavior in the other (most used) frameworks. We should invest some time in properly documenting the behavior and expected workflow (how to use the dvc-tracked artifacts later), though |
Didn't look in details, but seems like the other loggers use _scan_checkpoints to only track the ones related to the current experiment |
Great idea. I'll look into it. |
Lightning warns if the directory is not empty: UserWarning: Checkpoint directory /Users/dave/Code/lstm_seq2seq/model exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Removing the files only happens after the checkpoint is saved, so sometimes the first checkpoint will still get a version number like: $ tree model
model
├── epoch=0-step=2-v1.ckpt # saved this checkpoint before previous one was dropped
├── epoch=1-step=4.ckpt
├── epoch=2-step=6.ckpt
├── epoch=3-step=8.ckpt
└── epoch=4-step=10.ckpt Overall, this works and probably meets most user's expectations, so I think we should keep it, but I don't feel strongly that it outweighs the added complexity or potential surprise that dvclive is deleting model checkpoints. |
ping @daavoo |
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #613 +/- ##
==========================================
- Coverage 89.47% 88.06% -1.42%
==========================================
Files 44 43 -1
Lines 2994 3042 +48
Branches 250 260 +10
==========================================
Hits 2679 2679
- Misses 276 324 +48
Partials 39 39
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a high level, the code and description make sense to me.
I have not actually tried in a project the different options, but the test looks reasonable so trusting that.
if str(p) not in self._all_checkpoint_paths: | ||
p.unlink(missing_ok=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's be clear about this in the docs
I think I would like to have a |
This PR auto logs models with
dvclive.lightning.DVCLiveLogger(log_model=True)
:log_model
follows the conventions in mlflow and wandb:False
saves no models (this is the default).True
saves all model checkpoints at the end of training."all"
saves all model checkpoints whenever a model checkpoint is saved.If
log_model
isTrue
or"all"
, dvclive caches the entire checkpoints folder.Dvclive will also add a model artifact named "best" at the end of training that references the best model checkpoint. (edit: this resembles the best alias in wandb)
Edit: example
dvclive/dvc.yaml
output:To support this,
log_artifact
was also changed:dvc.yaml:artifacts
if some metadata is provided (type, name, desc, labels, or meta). This is a breaking change, but I can't see how anyone is making use of this without any metadata since it won't be used by the model registry.cache
kwarg tolog_artifact
(defaults toTrue
) so that it's possible to add the artifact metadata without caching the object.Related:
log_artifact
: external and non-DVC tracked files support #551dvc add
fromlog_artifact
#572log_model
#586To do: