Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto save model in lightning #613

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
# ruff: noqa: ARG002
import inspect
from typing import Any, Dict, Optional
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

try:
from lightning.fabric.utilities.logger import (
_convert_params,
_sanitize_callable_params,
_sanitize_params,
)
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.loggers.utilities import _scan_checkpoints
from lightning.pytorch.utilities import rank_zero_only
except ImportError:
from lightning_fabric.utilities.logger import (
_convert_params,
_sanitize_callable_params,
_sanitize_params,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities import rank_zero_only

from torch import is_tensor
Expand Down Expand Up @@ -44,10 +49,11 @@ def _should_call_next_step():


class DVCLiveLogger(Logger):
def __init__(
def __init__( # noqa: PLR0913
self,
run_name: Optional[str] = "dvclive_run",
prefix="",
log_model: Union[str, bool] = False,
experiment=None,
dir: Optional[str] = None, # noqa: A002
resume: bool = False,
Expand All @@ -72,6 +78,10 @@ def __init__(
if report == "notebook":
# Force Live instantiation
self.experiment # noqa: B018
self._log_model = log_model
self._logged_model_time: Dict[str, float] = {}
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self._all_checkpoint_paths: List[str] = []

@property
def name(self):
Expand Down Expand Up @@ -131,6 +141,42 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
self.experiment._latest_studio_step -= 1 # noqa: SLF001
self.experiment.next_step()

def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Track checkpoints of `checkpoint_callback` according to ``log_model``.

    Nothing happens unless ``log_model`` is ``True`` or ``"all"``.  Otherwise
    the callback is stored on the logger, its checkpoints are scanned, and the
    checkpoint directory is saved immediately when either ``log_model`` is
    ``"all"`` or ``log_model`` is ``True`` and the callback's ``save_top_k``
    is ``-1``.

    Args:
        checkpoint_callback: Callback whose checkpoints should be tracked.
    """
    mode = self._log_model
    if mode not in [True, "all"]:
        # Model logging is disabled.
        return

    self._checkpoint_callback = checkpoint_callback
    self._scan_checkpoints(checkpoint_callback)

    # `save_top_k` is only consulted when `log_model` is exactly True.
    keeps_every_checkpoint = mode is True and checkpoint_callback.save_top_k == -1
    if mode == "all" or keeps_every_checkpoint:
        self._save_checkpoints(checkpoint_callback)

@rank_zero_only
def finalize(self, status: str) -> None:
    """Save tracked checkpoints, log the best model, and end the experiment.

    When a checkpoint callback has been stored on the logger, its checkpoints
    are scanned and saved once more, and the callback's ``best_model_path`` is
    logged as a model artifact named ``"best"`` (with ``cache=False``).  The
    experiment is always ended, whether or not a callback was stored.

    Args:
        status: Not used by this method.
    """
    # NOTE(review): the best model is logged at its location inside the
    # checkpoint directory, so the path of the registered model can change
    # between experiments.  Copying it to a stable place (the "dvclive" folder
    # or the checkpoints folder itself) was suggested in review.  The author
    # noted that other trackers appear to behave the same way and that lookup
    # by artifact name may be enough for `dvc get`; no decision was recorded.
    callback = self._checkpoint_callback
    if callback:
        self._scan_checkpoints(callback)
        self._save_checkpoints(callback)
        best_model_path = callback.best_model_path
        # Lightning's ModelCheckpoint starts with an empty `best_model_path`
        # (TODO confirm for every supported version); do not register an
        # empty path as the "best" model.
        if best_model_path:
            self.experiment.log_artifact(
                best_model_path, name="best", type="model", cache=False
            )
    self.experiment.end()

def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Record checkpoints reported for `checkpoint_callback`.

    Every entry returned by the module-level ``_scan_checkpoints`` helper has
    its first field stored in ``self._logged_model_time`` under its second
    field, and that second field is appended to ``self._all_checkpoint_paths``.

    Args:
        checkpoint_callback: Callback handed to the scanning helper.
    """
    # Inside the method the bare name resolves to the imported module-level
    # helper, not to this method.
    found = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

    # Entries are 4-tuples; presumably (time, path, score, tag) -- confirm
    # against the Lightning helper.  Only the first two fields are used.
    for logged_at, checkpoint_path, _score, _tag in found:
        self._all_checkpoint_paths.append(checkpoint_path)
        self._logged_model_time[checkpoint_path] = logged_at

def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
    """Prune untracked files from the checkpoint directory, then log it.

    Unless the experiment is resuming, every entry directly inside
    ``checkpoint_callback.dirpath`` whose path is not listed in
    ``self._all_checkpoint_paths`` is deleted; sub-directories are left
    alone.  The directory itself is then logged as an artifact.

    Args:
        checkpoint_callback: Callback whose ``dirpath`` is pruned and logged.
    """
    checkpoint_dir = checkpoint_callback.dirpath

    # NOTE(review): this step DELETES files from the checkpoint directory;
    # reviewers asked for this behavior to be stated clearly in the docs.
    #
    # Use the `experiment` property (as the `log_artifact` call below does)
    # rather than the private `_experiment` attribute.
    if not self.experiment._resume:  # noqa: SLF001
        # A set gives constant-time membership checks per directory entry.
        tracked = set(self._all_checkpoint_paths)
        for entry in Path(checkpoint_dir).iterdir():
            if str(entry) in tracked:
                continue
            if entry.is_dir() and not entry.is_symlink():
                # `Path.unlink` cannot remove a directory and would raise.
                continue
            entry.unlink(missing_ok=True)

    # save directory
    self.experiment.log_artifact(checkpoint_dir)
39 changes: 36 additions & 3 deletions tests/test_frameworks/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

try:
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.trainer import Trainer
from lightning import LightningModule
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn
from torch.nn import functional as F # noqa: N812
from torch.optim import SGD, Adam
Expand All @@ -18,7 +19,7 @@
from dvclive import Live
from dvclive.lightning import DVCLiveLogger
except ImportError:
pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
pytest.skip("skipping lightning tests", allow_module_level=True)


class XORDataset(Dataset):
Expand Down Expand Up @@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir):
assert dvclive_logger.experiment._cache_images is True


@pytest.mark.parametrize("log_model", [False, True, "all"])
@pytest.mark.parametrize("save_top_k", [1, -1])
def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
    """Check artifact logging and checkpoint pruning for each `log_model` mode."""
    model = LitXOR()
    dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
    trainer = Trainer(
        logger=dvclive_logger,
        max_epochs=2,
        log_every_n_steps=1,
        callbacks=[ModelCheckpoint(dirpath="model", save_top_k=save_top_k)],
    )
    log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
    trainer.fit(model)

    # Check how often log_artifact was called.
    if log_model is False:
        log_artifact.assert_not_called()
    else:
        if log_model is True and save_top_k != -1:
            # Called once to cache, then again to log the best artifact.
            expected_calls = 2
        else:
            # Once per epoch, plus the two calls at the end (see above).
            expected_calls = 4
        assert log_artifact.call_count == expected_calls

    # Check that the number of checkpoint files does not grow with each run.
    checkpoint_dir = tmp_dir / "model"
    num_checkpoints = len(os.listdir(checkpoint_dir))
    if log_model in [True, "all"]:
        trainer.fit(model)
        assert len(os.listdir(checkpoint_dir)) == num_checkpoints


def test_lightning_steps(tmp_dir, mocker):
model = LitXOR()
# Handle kwargs passed to Live.
Expand Down