From ff48a1cdbd12879842a0ee7aa833f4dddab0b083 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 18 Oct 2023 15:48:57 -0400 Subject: [PATCH 1/3] custom exp name --- src/dvclive/live.py | 8 ++++++-- tests/test_dvc.py | 11 +++++++++++ tests/test_post_to_studio.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 325f9e32..5d59831c 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -65,6 +65,7 @@ def __init__( save_dvc_exp: bool = True, dvcyaml: Union[str, bool] = True, cache_images: bool = False, + exp_name: Optional[str] = None, exp_message: Optional[str] = None, ): self.summary: Dict[str, Any] = {} @@ -89,7 +90,7 @@ def __init__( self._init_report() self._baseline_rev: Optional[str] = None - self._exp_name: Optional[str] = None + self._exp_name: Optional[str] = exp_name self._exp_message: Optional[str] = exp_message self._experiment_rev: Optional[str] = None self._inside_dvc_exp: bool = False @@ -178,7 +179,10 @@ def _init_dvc(self): self._baseline_rev = self._dvc_repo.scm.get_rev() if self._save_dvc_exp: - self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev) + if not self._exp_name: + self._exp_name = get_random_exp_name( + self._dvc_repo.scm, self._baseline_rev + ) mark_dvclive_only_started(self._exp_name) self._include_untracked.append(self.dir) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index d3c5055d..b82cea4d 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -181,6 +181,17 @@ def test_exp_save_message(tmp_dir, mocked_dvc_repo): ) +def test_exp_save_name(tmp_dir, mocked_dvc_repo): + live = Live(exp_name="custom-name") + live.end() + mocked_dvc_repo.experiments.save.assert_called_with( + name="custom-name", + include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")], + force=True, + message=None, + ) + + def test_no_scm_repo(tmp_dir, mocker): dvc_repo = mocker.MagicMock() dvc_repo.scm = NoSCM() diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 3591a831..91a9d732 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -359,3 +359,14 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name, message="Custom message"), ) + + +def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post): + Live(exp_name="custom-name") + + mocked_post, _ = mocked_studio_post + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call("start", exp_name="custom-name"), + ) From 30a6faf9b7d0f9e58197489c3418a3dd4d9797f1 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 18 Oct 2023 16:49:03 -0400 Subject: [PATCH 2/3] add to lightning --- src/dvclive/lightning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 9f56c5b1..07991df9 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -62,6 +62,7 @@ def __init__( # noqa: PLR0913 save_dvc_exp: bool = False, dvcyaml: bool = True, cache_images: bool = False, + exp_name: Optional[str] = None, ): super().__init__() self._prefix = prefix From 691839644ac4cb2372f5beae7b3bc2621f6c6da7 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 19 Oct 2023 10:52:28 -0400 Subject: [PATCH 3/3] validate exp name --- src/dvclive/dvc.py | 26 ++++++++++++++++++++------ src/dvclive/live.py | 10 +++++----- tests/test_dvc.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index 4aafb1e9..2a1ade36 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -1,5 +1,6 @@ # ruff: noqa: SLF001 import copy +import logging import os from pathlib import Path from typing import TYPE_CHECKING, Any, List, Optional @@ -12,6 +13,8 @@ from dvc.repo import Repo from dvc.stage import Stage +logger = logging.getLogger("dvclive") + def _dvc_dir(dirname: StrPath) -> str: return os.path.join(dirname, ".dvc") @@ -125,12 +128,23 @@ def _update_entries(old, new, key): del orig["artifacts"] -def get_random_exp_name(scm, baseline_rev) -> str: - from dvc.repo.experiments.utils import ( - get_random_exp_name as dvc_get_random_exp_name, - ) - - return dvc_get_random_exp_name(scm, baseline_rev) +def get_exp_name(name, scm, baseline_rev) -> str: + from dvc.exceptions import InvalidArgumentError + from dvc.repo.experiments.refs import ExpRefInfo + from dvc.repo.experiments.utils import check_ref_format, get_random_exp_name + + if name: + ref = ExpRefInfo(baseline_sha=baseline_rev, name=name) + if scm.get_ref(str(ref)): + logger.warning(f"Experiment conflicts with existing experiment '{name}'.") + else: + try: + check_ref_format(scm, ref) + except InvalidArgumentError as e: + logger.warning(e) + else: + return name + return get_random_exp_name(scm, baseline_rev) def find_overlapping_stage(dvc_repo: "Repo", path: StrPath) -> Optional["Stage"]: diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 5d59831c..5491bc91 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -16,7 +16,7 @@ ensure_dir_is_tracked, find_overlapping_stage, get_dvc_repo, - get_random_exp_name, + get_exp_name, make_dvcyaml, ) from .error import ( @@ -179,10 +179,10 @@ def _init_dvc(self): self._baseline_rev = self._dvc_repo.scm.get_rev() if self._save_dvc_exp: - if not self._exp_name: - self._exp_name = get_random_exp_name( - self._dvc_repo.scm, self._baseline_rev - ) + self._exp_name = get_exp_name( + self._exp_name, self._dvc_repo.scm, self._baseline_rev + ) + logger.info(f"Logging to experiment '{self._exp_name}'") mark_dvclive_only_started(self._exp_name) self._include_untracked.append(self.dir) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index b82cea4d..9f0445cc 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -209,3 +209,38 @@ def test_dvc_repro(tmp_dir, monkeypatch, mocker): mocker.patch("dvclive.live.get_dvc_repo", return_value=None) live = Live(save_dvc_exp=True) assert not live._save_dvc_exp + + +def test_get_exp_name_valid(tmp_dir, mocked_dvc_repo): + live = Live(exp_name="name") + assert live._exp_name == "name" + + +def test_get_exp_name_random(tmp_dir, mocked_dvc_repo, mocker): + mocker.patch( + "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" + ) + live = Live() + assert live._exp_name == "random" + + +def test_get_exp_name_invalid(tmp_dir, mocked_dvc_repo, mocker, caplog): + mocker.patch( + "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" + ) + with caplog.at_level("WARNING"): + live = Live(exp_name="invalid//name") + assert live._exp_name == "random" + assert caplog.text + + +def test_get_exp_name_duplicate(tmp_dir, mocked_dvc_repo, mocker, caplog): + mocker.patch( + "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" + ) + mocked_dvc_repo.scm.get_ref.return_value = "duplicate" + with caplog.at_level("WARNING"): + live = Live(exp_name="duplicate") + assert live._exp_name == "random" + msg = "Experiment conflicts with existing experiment 'duplicate'." + assert msg in caplog.text