diff --git a/dvc/commands/experiments/save.py b/dvc/commands/experiments/save.py index 2c7f3763fee..e99e4dfff85 100644 --- a/dvc/commands/experiments/save.py +++ b/dvc/commands/experiments/save.py @@ -14,7 +14,9 @@ def run(self): try: ref = self.repo.experiments.save( - name=self.args.name, force=self.args.force + name=self.args.name, + force=self.args.force, + include_untracked=self.args.include_untracked, ) except DvcException: logger.exception("failed to save experiment") @@ -66,4 +68,12 @@ def add_parser(experiments_subparsers, parent_parser): ), metavar="", ) + save_parser.add_argument( + "-I", + "--include-untracked", + action="append", + default=[], + help="List of untracked files to include in the experiment.", + metavar="", + ) save_parser.set_defaults(func=CmdExperimentsSave) diff --git a/dvc/repo/experiments/executor/local.py b/dvc/repo/experiments/executor/local.py index 45f090f2c7f..ba548eb8399 100644 --- a/dvc/repo/experiments/executor/local.py +++ b/dvc/repo/experiments/executor/local.py @@ -2,7 +2,7 @@ import os from contextlib import ExitStack from tempfile import mkdtemp -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional from funcy import cached_property from scmrepo.exceptions import SCMError as _SCMError @@ -221,6 +221,7 @@ def save( cls, info: "ExecutorInfo", force: bool = False, + include_untracked: Optional[List[str]] = None, ) -> ExecutorResult: from dvc.repo import Repo @@ -237,6 +238,8 @@ def save( try: stages = dvc.commit([], force=force) exp_hash = cls.hash_exp(stages) + if include_untracked: + dvc.scm.add(include_untracked) cls.commit( dvc.scm, exp_hash, diff --git a/dvc/repo/experiments/save.py b/dvc/repo/experiments/save.py index 692a4b6c7c6..260521aa8ea 100644 --- a/dvc/repo/experiments/save.py +++ b/dvc/repo/experiments/save.py @@ -15,6 +15,7 @@ def save( repo: "Repo", name: Optional[str] = None, force: bool = False, + include_untracked: Optional[List[str]] = None, ) -> Optional[str]: """Save the current workspace status as an experiment. @@ -33,7 +34,10 @@ def save( entry = repo.experiments.new(queue=queue, name=name, force=force) executor = queue.init_executor(repo.experiments, entry) - save_result = executor.save(executor.info, force=force) + + save_result = executor.save( + executor.info, force=force, include_untracked=include_untracked + ) result = queue.collect_executor(repo.experiments, executor, save_result) exp_rev = first(result) diff --git a/tests/func/experiments/test_save.py b/tests/func/experiments/test_save.py index 2427a00dda3..429ec194454 100644 --- a/tests/func/experiments/test_save.py +++ b/tests/func/experiments/test_save.py @@ -49,22 +49,6 @@ def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage): dvc.experiments.save(name="dummy", force=True) -def test_exp_save_multiple(tmp_dir, dvc, scm): - baseline = scm.get_rev() - for i in range(2): - name = f"exp-{i}" - tmp_dir.gen({name: f"{name} content"}) - dvc.experiments.save(name=name) - - assert dvc.experiments.ls()[baseline] == ["exp-0", "exp-1"] - - for i in range(2): - scm.reset(hard=True) - name = f"exp-{i}" - dvc.experiments.apply(name) - assert (tmp_dir / name).read_text() == f"{name} content" - - def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage): baseline = scm.get_rev() dvc.experiments.save(name="exp-1") @@ -85,4 +69,18 @@ def test_exp_save_with_staged_changes(tmp_dir, dvc, scm): dvc.experiments.save(name="exp") _, _, unstaged = scm.status() - assert "new_file" in unstaged \ No newline at end of file + assert "new_file" in unstaged + + +def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage): + new_file = tmp_dir / "new_file" + for i in range(2): + new_file.write_text(f"exp-{i}") + dvc.experiments.save(name=f"exp-{i}", include_untracked=["new_file"]) + + _, _, unstaged = scm.status() + assert "new_file" in unstaged + assert new_file.read_text() == f"exp-{i}" + + dvc.experiments.apply("exp-0") + assert new_file.read_text() == "exp-0" diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index e906411f124..36ff5c352ea 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -946,4 +946,6 @@ def test_experiments_save(dvc, scm, mocker): assert cmd.run() == 0 - m.assert_called_once_with(cmd.repo, name="exp-name", force=True) + m.assert_called_once_with( + cmd.repo, name="exp-name", force=True, include_untracked=[] + )