diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py
index 4ce756cf..7dc2c0f6 100644
--- a/src/dvclive/lightning.py
+++ b/src/dvclive/lightning.py
@@ -1,6 +1,7 @@
 # ruff: noqa: ARG002
 import inspect
-from typing import Any, Dict, Optional
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
 
 try:
     from lightning.fabric.utilities.logger import (
@@ -8,7 +9,9 @@
         _sanitize_callable_params,
         _sanitize_params,
     )
+    from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
     from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
+    from lightning.pytorch.loggers.utilities import _scan_checkpoints
     from lightning.pytorch.utilities import rank_zero_only
 except ImportError:
     from lightning_fabric.utilities.logger import (
@@ -16,7 +19,9 @@
         _sanitize_callable_params,
         _sanitize_params,
     )
+    from pytorch_lightning.callbacks import ModelCheckpoint
     from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
+    from pytorch_lightning.loggers.utilities import _scan_checkpoints
     from pytorch_lightning.utilities import rank_zero_only
 
 from torch import is_tensor
@@ -44,10 +49,11 @@
 
 
 class DVCLiveLogger(Logger):
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         run_name: Optional[str] = "dvclive_run",
         prefix="",
+        log_model: Union[str, bool] = False,
         experiment=None,
         dir: Optional[str] = None,  # noqa: A002
         resume: bool = False,
@@ -72,6 +78,10 @@ def __init__(
         if report == "notebook":
             # Force Live instantiation
             self.experiment  # noqa: B018
+        self._log_model = log_model
+        self._logged_model_time: Dict[str, float] = {}
+        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._all_checkpoint_paths: List[str] = []
 
     @property
     def name(self):
@@ -131,6 +141,42 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
             self.experiment._latest_studio_step -= 1  # noqa: SLF001
         self.experiment.next_step()
 
+    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
+        if self._log_model in [True, "all"]:
+            self._checkpoint_callback = checkpoint_callback
+            self._scan_checkpoints(checkpoint_callback)
+        if self._log_model == "all" or (
+            self._log_model is True and checkpoint_callback.save_top_k == -1
+        ):
+            self._save_checkpoints(checkpoint_callback)
+
     @rank_zero_only
     def finalize(self, status: str) -> None:
+        # Log best model.
+        if self._checkpoint_callback:
+            self._scan_checkpoints(self._checkpoint_callback)
+            self._save_checkpoints(self._checkpoint_callback)
+            best_model_path = self._checkpoint_callback.best_model_path
+            self.experiment.log_artifact(
+                best_model_path, name="best", type="model", cache=False
+            )
         self.experiment.end()
+
+    def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+        # Get checkpoints to be saved, with their associated scores.
+        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
+
+        # Update model save times and append paths to the list of all checkpoints.
+        for t, p, _, _ in checkpoints:
+            self._logged_model_time[p] = t
+            self._all_checkpoint_paths.append(p)
+
+    def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+        # Drop unused checkpoints.
+        if not self._experiment._resume:  # noqa: SLF001
+            for p in Path(checkpoint_callback.dirpath).iterdir():
+                if str(p) not in self._all_checkpoint_paths:
+                    p.unlink(missing_ok=True)
+
+        # Save the whole checkpoint directory.
+        self.experiment.log_artifact(checkpoint_callback.dirpath)
diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py
index 8aa4fbf5..e572b235 100644
--- a/tests/test_frameworks/test_lightning.py
+++ b/tests/test_frameworks/test_lightning.py
@@ -8,8 +8,9 @@
 
 try:
     import torch
-    from pytorch_lightning import LightningModule
-    from pytorch_lightning.trainer import Trainer
+    from lightning import LightningModule
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.callbacks import ModelCheckpoint
     from torch import nn
     from torch.nn import functional as F  # noqa: N812
     from torch.optim import SGD, Adam
@@ -18,7 +19,7 @@
     from dvclive import Live
     from dvclive.lightning import DVCLiveLogger
 except ImportError:
-    pytest.skip("skipping pytorch_lightning tests", allow_module_level=True)
+    pytest.skip("skipping lightning tests", allow_module_level=True)
 
 
 class XORDataset(Dataset):
@@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir):
     assert dvclive_logger.experiment._cache_images is True
 
 
+@pytest.mark.parametrize("log_model", [False, True, "all"])
+@pytest.mark.parametrize("save_top_k", [1, -1])
+def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
+    model = LitXOR()
+    dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
+    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k)
+    trainer = Trainer(
+        logger=dvclive_logger,
+        max_epochs=2,
+        log_every_n_steps=1,
+        callbacks=[checkpoint],
+    )
+    log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
+    trainer.fit(model)
+
+    # Check that log_artifact is called.
+    if log_model is False:
+        log_artifact.assert_not_called()
+    elif (log_model is True) and (save_top_k != -1):
+        # Called once to cache the checkpoints, then again to log the best model.
+        assert log_artifact.call_count == 2
+    else:
+        # Called once per epoch, plus the two calls at the end (see above).
+        assert log_artifact.call_count == 4
+
+    # Check that the number of checkpoint files does not grow with each run.
+    num_checkpoints = len(os.listdir(tmp_dir / "model"))
+    if log_model in [True, "all"]:
+        trainer.fit(model)
+        assert len(os.listdir(tmp_dir / "model")) == num_checkpoints
+
+
 def test_lightning_steps(tmp_dir, mocker):
     model = LitXOR()
     # Handle kwargs passed to Live.
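
For reviewers, a minimal usage sketch of the new `log_model` option. The `TinyRegressor` module below is a placeholder for illustration only (it is not part of this PR; any `LightningModule` works):

```python
import torch
from lightning import LightningModule
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from dvclive.lightning import DVCLiveLogger


class TinyRegressor(LightningModule):
    """Placeholder module for the sketch; any LightningModule works."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x, y = torch.randn(64, 4), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)


# log_model=True: the checkpoint directory is logged once when training ends
# (and after every save if save_top_k == -1), and the best checkpoint is
# logged as a "best" model artifact in finalize().
# log_model="all": the checkpoint directory is logged after every save.
logger = DVCLiveLogger(log_model=True)
checkpoint = ModelCheckpoint(dirpath="model", save_top_k=1)
trainer = Trainer(logger=logger, callbacks=[checkpoint], max_epochs=2)
trainer.fit(TinyRegressor())
```

Unless the run is resumed, checkpoint files in `dirpath` that the logger never tracked are deleted before the directory is logged, so repeated runs do not accumulate stale checkpoints.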