From 24e10e0278c1bbfeb37d4414b48c84fbded1165f Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Mon, 27 Mar 2023 22:02:05 +0200 Subject: [PATCH] exp save: Track top level files --- dvc/repo/experiments/executor/base.py | 9 +++++++++ tests/func/experiments/test_save.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 74bf160d0a..2d1e670438 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -34,6 +34,8 @@ ExpRefInfo, ) from dvc.repo.experiments.utils import to_studio_params +from dvc.repo.metrics.show import _collect_top_level_metrics +from dvc.repo.params.show import _collect_top_level_params from dvc.stage.serialize import to_lockfile from dvc.ui import ui from dvc.utils import dict_sha256, env2bool, relpath @@ -275,6 +277,13 @@ def save( else: os.chdir(dvc.root_dir) + include_untracked = include_untracked or [] + include_untracked.extend(_collect_top_level_metrics(dvc)) + include_untracked.extend(_collect_top_level_params(dvc)) + include_untracked.extend( + dvc.index._plot_sources # pylint: disable=protected-access + ) + try: stages = dvc.commit([], force=True, relink=False) exp_hash = cls.hash_exp(stages) diff --git a/tests/func/experiments/test_save.py b/tests/func/experiments/test_save.py index c756d5b371..c77c1bc0d7 100644 --- a/tests/func/experiments/test_save.py +++ b/tests/func/experiments/test_save.py @@ -116,3 +116,23 @@ def test_exp_save_include_untracked_warning(tmp_dir, dvc, scm, mocker): dvc.experiments.save(name="exp", include_untracked=["new_dir"]) assert not logger.warning.called + + +def test_untracked_top_level_files_are_included_in_exp(tmp_dir, scm, dvc): + (tmp_dir / "dvc.yaml").dump( + { + "metrics": ["metrics.json"], + "params": ["params.yaml"], + "plots": ["plots.csv"], + } + ) + stage = dvc.stage.add( + cmd="touch metrics.json && touch params.yaml && touch plots.csv", + name="top-level", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + dvc.reproduce(stage.addressing) + exp = dvc.experiments.save() + fs = scm.get_fs(exp) + for file in ["metrics.json", "params.yaml", "plots.csv"]: + assert fs.exists(file)