From 74f06094e8073e15a0d014436e28f3f465aa0191 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 3 Jul 2023 15:12:38 -0400 Subject: [PATCH 1/5] auto save model in lightning --- src/dvclive/lightning.py | 19 ++++++++++++++++++- src/dvclive/live.py | 40 ++++++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index cfb8375e..968bfb99 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -1,12 +1,13 @@ # ruff: noqa: ARG002 import inspect -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from lightning.fabric.utilities.logger import ( _convert_params, _sanitize_callable_params, _sanitize_params, ) +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities import rank_zero_only from torch import is_tensor @@ -38,6 +39,7 @@ def __init__( self, run_name: Optional[str] = "dvclive_run", prefix="", + log_model: Union[str, bool] = False, experiment=None, dir: Optional[str] = None, # noqa: A002 resume: bool = False, @@ -60,6 +62,8 @@ def __init__( if report == "notebook": # Force Live instantiation self.experiment # noqa: B018 + self._log_model = log_model + self._checkpoint_callback: Optional[ModelCheckpoint] = None @property def name(self): @@ -119,6 +123,19 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): self.experiment._latest_studio_step -= 1 # noqa: SLF001 self.experiment.next_step() + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + self._checkpoint_callback = checkpoint_callback + if self._log_model == "all": + self.experiment.log_artifact(checkpoint_callback.dirpath) + @rank_zero_only def finalize(self, status: str) -> None: + checkpoint_callback = self._checkpoint_callback + # Save model checkpoints. 
+ if self._log_model is True: + self.experiment.log_artifact(checkpoint_callback.dirpath) + # Log best model. + if self._log_model in (True, "all"): + best_model_path = checkpoint_callback.best_model_path + self.experiment.log_artifact(best_model_path, name="best", cache=False) self.experiment.end() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d70b6a23..deb3b968 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -410,10 +410,11 @@ def log_artifact( path: StrPath, type: Optional[str] = None, # noqa: A002 name: Optional[str] = None, - desc: Optional[str] = None, # noqa: ARG002 - labels: Optional[List[str]] = None, # noqa: ARG002 - meta: Optional[Dict[str, Any]] = None, # noqa: ARG002 + desc: Optional[str] = None, + labels: Optional[List[str]] = None, + meta: Optional[Dict[str, Any]] = None, copy: bool = False, + cache: bool = True, ): """Tracks a local file or directory with DVC""" if not isinstance(path, (str, Path)): @@ -425,21 +426,24 @@ def log_artifact( if copy: path = clean_and_copy_into(path, self.artifacts_dir) - self.cache(path) - - name = name or Path(path).stem - if name_is_compatible(name): - self._artifacts[name] = { - k: v - for k, v in locals().items() - if k in ("path", "type", "desc", "labels", "meta") and v is not None - } - else: - logger.warning( - "Can't use '%s' as artifact name (ID)." - " It will not be included in the `artifacts` section.", - name, - ) + if cache: + self.cache(path) + + if any((type, name, desc, labels, meta)): + name = name or Path(path).stem + if name_is_compatible(name): + self._artifacts[name] = { + k: v + for k, v in locals().items() + if k in ("path", "type", "desc", "labels", "meta") + and v is not None + } + else: + logger.warning( + "Can't use '%s' as artifact name (ID)." 
+ " It will not be included in the `artifacts` section.", + name, + ) def cache(self, path): try: From dbd354b95278ed7509d016d560bfad34bb8e93d7 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 4 Jul 2023 14:53:41 -0400 Subject: [PATCH 2/5] lightning: save model at each checkpoint if save_top_k == -1 --- src/dvclive/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 968bfb99..43f6c747 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -125,7 +125,9 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: self._checkpoint_callback = checkpoint_callback - if self._log_model == "all": + if self._log_model == "all" or ( + self._log_model is True and checkpoint_callback.save_top_k == -1 + ): self.experiment.log_artifact(checkpoint_callback.dirpath) @rank_zero_only From 15932a8400a5202f0d7aa92159dccd9e927e092f Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 5 Jul 2023 08:18:39 -0400 Subject: [PATCH 3/5] add type: model to lightning artifact --- src/dvclive/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 43f6c747..4aaff1b0 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -139,5 +139,7 @@ def finalize(self, status: str) -> None: # Log best model. 
if self._log_model in (True, "all"): best_model_path = checkpoint_callback.best_model_path - self.experiment.log_artifact(best_model_path, name="best", cache=False) + self.experiment.log_artifact( + best_model_path, name="best", type="model", cache=False + ) self.experiment.end() From 110b9aa575f15b802e6c9effb43375ded4761935 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 11 Jul 2023 14:34:20 -0400 Subject: [PATCH 4/5] lightning: drop unused checkpoints --- src/dvclive/lightning.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 8f33936b..39361fa7 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -1,6 +1,7 @@ # ruff: noqa: ARG002 import inspect -from typing import Any, Dict, Optional, Union +from pathlib import Path +from typing import Any, Dict, List, Optional, Union from lightning.fabric.utilities.logger import ( _convert_params, @@ -9,6 +10,7 @@ ) from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment +from lightning.pytorch.loggers.utilities import _scan_checkpoints from lightning.pytorch.utilities import rank_zero_only from torch import is_tensor @@ -35,7 +37,7 @@ def _should_call_next_step(): class DVCLiveLogger(Logger): - def __init__( + def __init__( # noqa: PLR0913 self, run_name: Optional[str] = "dvclive_run", prefix="", @@ -65,7 +67,9 @@ def __init__( # Force Live instantiation self.experiment # noqa: B018 self._log_model = log_model + self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._all_checkpoint_paths: List[str] = [] @property def name(self): @@ -130,14 +134,14 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: if self._log_model == "all" or ( self._log_model is True and checkpoint_callback.save_top_k == -1 ): - 
self.experiment.log_artifact(checkpoint_callback.dirpath) + self._save_checkpoints(checkpoint_callback) @rank_zero_only def finalize(self, status: str) -> None: checkpoint_callback = self._checkpoint_callback # Save model checkpoints. if self._log_model is True: - self.experiment.log_artifact(checkpoint_callback.dirpath) + self._save_checkpoints(checkpoint_callback) # Log best model. if self._log_model in (True, "all"): best_model_path = checkpoint_callback.best_model_path @@ -145,3 +149,22 @@ def finalize(self, status: str) -> None: best_model_path, name="best", type="model", cache=False ) self.experiment.end() + + def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + # get checkpoints to be saved with associated score + checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) + + # update model time and append path to list of all checkpoints + for t, p, _, _ in checkpoints: + self._logged_model_time[p] = t + self._all_checkpoint_paths.append(p) + + def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + # drop unused checkpoints + if not self._experiment._resume: # noqa: SLF001 + for p in Path(checkpoint_callback.dirpath).iterdir(): + if str(p) not in self._all_checkpoint_paths: + p.unlink(missing_ok=True) + + # save directory + self.experiment.log_artifact(checkpoint_callback.dirpath) From 48309f5ec54bc779f0d791507d0a85f1aae6b28a Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 12 Jul 2023 18:30:14 -0400 Subject: [PATCH 5/5] lightning: add tests for log_model --- src/dvclive/lightning.py | 14 ++++----- tests/test_frameworks/test_lightning.py | 39 +++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 39361fa7..fef5de26 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -130,7 +130,9 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): 
self.experiment.next_step() def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: - self._checkpoint_callback = checkpoint_callback + if self._log_model in [True, "all"]: + self._checkpoint_callback = checkpoint_callback + self._scan_checkpoints(checkpoint_callback) if self._log_model == "all" or ( self._log_model is True and checkpoint_callback.save_top_k == -1 ): @@ -138,13 +140,11 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: @rank_zero_only def finalize(self, status: str) -> None: - checkpoint_callback = self._checkpoint_callback - # Save model checkpoints. - if self._log_model is True: - self._save_checkpoints(checkpoint_callback) # Log best model. - if self._log_model in (True, "all"): - best_model_path = checkpoint_callback.best_model_path + if self._checkpoint_callback: + self._scan_checkpoints(self._checkpoint_callback) + self._save_checkpoints(self._checkpoint_callback) + best_model_path = self._checkpoint_callback.best_model_path self.experiment.log_artifact( best_model_path, name="best", type="model", cache=False ) diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index 8aa4fbf5..e572b235 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -8,8 +8,9 @@ try: import torch - from pytorch_lightning import LightningModule - from pytorch_lightning.trainer import Trainer + from lightning import LightningModule + from lightning.pytorch import Trainer + from lightning.pytorch.callbacks import ModelCheckpoint from torch import nn from torch.nn import functional as F # noqa: N812 from torch.optim import SGD, Adam @@ -18,7 +19,7 @@ from dvclive import Live from dvclive.lightning import DVCLiveLogger except ImportError: - pytest.skip("skipping pytorch_lightning tests", allow_module_level=True) + pytest.skip("skipping lightning tests", allow_module_level=True) class XORDataset(Dataset): @@ -161,6 +162,38 @@ def 
test_lightning_kwargs(tmp_dir): assert dvclive_logger.experiment._cache_images is True +@pytest.mark.parametrize("log_model", [False, True, "all"]) +@pytest.mark.parametrize("save_top_k", [1, -1]) +def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k): + model = LitXOR() + dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model) + checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k) + trainer = Trainer( + logger=dvclive_logger, + max_epochs=2, + log_every_n_steps=1, + callbacks=[checkpoint], + ) + log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact") + trainer.fit(model) + + # Check that log_artifact is called. + if log_model is False: + log_artifact.assert_not_called() + elif (log_model is True) and (save_top_k != -1): + # called once to cache, then again to log best artifact + assert log_artifact.call_count == 2 + else: + # once per epoch plus two calls at the end (see above) + assert log_artifact.call_count == 4 + + # Check that the number of checkpoint files does not grow with each run. + num_checkpoints = len(os.listdir(tmp_dir / "model")) + if log_model in [True, "all"]: + trainer.fit(model) + assert len(os.listdir(tmp_dir / "model")) == num_checkpoints + + def test_lightning_steps(tmp_dir, mocker): model = LitXOR() # Handle kwargs passed to Live.