diff --git a/dvc/commands/experiments/__init__.py b/dvc/commands/experiments/__init__.py index addf85b7d7..adbe6b9232 100644 --- a/dvc/commands/experiments/__init__.py +++ b/dvc/commands/experiments/__init__.py @@ -18,6 +18,7 @@ queue_worker, remove, run, + save, show, ) @@ -34,6 +35,7 @@ queue_worker, remove, run, + save, show, ] diff --git a/dvc/commands/experiments/save.py b/dvc/commands/experiments/save.py new file mode 100644 index 0000000000..20a5ec29da --- /dev/null +++ b/dvc/commands/experiments/save.py @@ -0,0 +1,79 @@ +import argparse +import logging + +from dvc.cli.command import CmdBase +from dvc.cli.utils import append_doc_link +from dvc.exceptions import DvcException +from dvc.ui import ui + +logger = logging.getLogger(__name__) + + +class CmdExperimentsSave(CmdBase): + def run(self): + + try: + ref = self.repo.experiments.save( + name=self.args.name, + force=self.args.force, + include_untracked=self.args.include_untracked, + ) + except DvcException: + logger.exception("failed to save experiment") + return 1 + + if self.args.json: + ui.write_json({"ref": ref}) + else: + name = self.repo.experiments.get_exact_name([ref])[ref] + ui.write(f"Experiment has been saved as: {name}") + ui.write( + "\nTo promote an experiment to a Git branch run:\n\n" + "\tdvc exp branch \n" + ) + + return 0 + + +def add_parser(experiments_subparsers, parent_parser): + EXPERIMENTS_SAVE_HELP = "Save current workspace as a dvc experiment." + save_parser = experiments_subparsers.add_parser( + "save", + parents=[parent_parser], + description=append_doc_link(EXPERIMENTS_SAVE_HELP, "exp/save"), + help=EXPERIMENTS_SAVE_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + save_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Save even if hash value for dependencies/outputs changed.", + ) + save_parser.add_argument( + "--json", + "--show-json", + action="store_true", + default=False, + help="Show output in JSON format.", + ) + save_parser.add_argument( + "-n", + "--name", + default=None, + help=( + "Human-readable experiment name. If not specified, a name will " + "be auto-generated." + ), + metavar="", + ) + save_parser.add_argument( + "-I", + "--include-untracked", + action="append", + default=[], + help="List of untracked paths to include in the experiment.", + metavar="", + ) + save_parser.set_defaults(func=CmdExperimentsSave) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index b335c29335..27167b86c9 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -479,6 +479,11 @@ def run(self, *args, **kwargs): return run(self.repo, *args, **kwargs) + def save(self, *args, **kwargs): + from dvc.repo.experiments.save import save + + return save(self.repo, *args, **kwargs) + def gc(self, *args, **kwargs): from dvc.repo.experiments.gc import gc diff --git a/dvc/repo/experiments/executor/local.py b/dvc/repo/experiments/executor/local.py index e919a1f4f7..cbf725c4d2 100644 --- a/dvc/repo/experiments/executor/local.py +++ b/dvc/repo/experiments/executor/local.py @@ -2,12 +2,13 @@ import os from contextlib import ExitStack from tempfile import mkdtemp -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional from funcy import cached_property, retry from scmrepo.exceptions import SCMError as _SCMError from shortuuid import uuid +from dvc.exceptions import DvcException from dvc.lock import LockError from dvc.scm import SCM, GitMergeError from dvc.utils.fs import makedirs, remove @@ -21,17 +22,18 @@ EXEC_MERGE, EXEC_NAMESPACE, EXPS_TEMP, + ExpRefInfo, ) from ..utils import EXEC_TMP_DIR, get_exp_rwlock -from .base import BaseExecutor, TaskStatus +from .base import BaseExecutor, ExecutorResult, TaskStatus if TYPE_CHECKING: from scmrepo.git import Git from dvc.repo import Repo - from ..refs import ExpRefInfo from ..stash import ExpStashEntry + from .base import ExecutorInfo logger = logging.getLogger(__name__) @@ -247,3 +249,57 @@ def cleanup(self, infofile: str): checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) if checkpoint and checkpoint != self._orig_checkpoint: self.scm.set_ref(EXEC_APPLY, checkpoint) + + @classmethod + def save( + cls, + info: "ExecutorInfo", + force: bool = False, + include_untracked: Optional[List[str]] = None, + ) -> ExecutorResult: + from dvc.repo import Repo + + exp_hash: Optional[str] = None + exp_ref: Optional[ExpRefInfo] = None + + dvc = Repo(os.path.join(info.root_dir, info.dvc_dir)) + old_cwd = os.getcwd() + if info.wdir: + os.chdir(os.path.join(dvc.scm.root_dir, info.wdir)) + else: + os.chdir(dvc.root_dir) + + try: + stages = dvc.commit([], force=force) + exp_hash = cls.hash_exp(stages) + if include_untracked: + dvc.scm.add(include_untracked) + cls.commit( + dvc.scm, + exp_hash, + exp_name=info.name, + force=force, + ) + ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False) + exp_ref = ExpRefInfo.from_ref(ref) if ref else None + untracked = dvc.scm.untracked_files() + if untracked: + logger.warning( + "The following untracked files were present in " + "the workspace before saving but " + "will not be included in the experiment commit:\n" + "\t%s", + ", ".join(untracked), + ) + info.result_hash = exp_hash + info.result_ref = ref + info.result_force = False + info.status = TaskStatus.SUCCESS + except DvcException: + info.status = TaskStatus.FAILED + raise + finally: + dvc.close() + os.chdir(old_cwd) + + return ExecutorResult(ref, exp_ref, info.result_force) diff --git a/dvc/repo/experiments/save.py b/dvc/repo/experiments/save.py new file mode 100644 index 0000000000..260521aa8e --- /dev/null +++ b/dvc/repo/experiments/save.py @@ -0,0 +1,44 @@ +import logging +import os +from typing import TYPE_CHECKING, List, Optional + +from funcy import first + +if TYPE_CHECKING: + from dvc.repo import Repo + + +logger = logging.getLogger(__name__) + + +def save( + repo: "Repo", + name: Optional[str] = None, + force: bool = False, + include_untracked: Optional[List[str]] = None, +) -> Optional[str]: + """Save the current workspace status as an experiment. + + Returns the saved experiment's SHAs. + """ + queue = repo.experiments.workspace_queue + logger.debug("Saving workspace in %s", os.getcwd()) + + staged, _, _ = repo.scm.status(untracked_files="no") + if staged: + logger.warning( + "Your workspace contains staged Git changes which will be " + "unstaged before saving this experiment." + ) + repo.scm.reset() + + entry = repo.experiments.new(queue=queue, name=name, force=force) + executor = queue.init_executor(repo.experiments, entry) + + save_result = executor.save( + executor.info, force=force, include_untracked=include_untracked + ) + result = queue.collect_executor(repo.experiments, executor, save_result) + + exp_rev = first(result) + return exp_rev diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index d8b5606df7..24811a6797 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -9,6 +9,7 @@ from dvc.dvcfile import PIPELINE_FILE from dvc.exceptions import ReproductionError +from dvc.repo.experiments.exceptions import ExperimentExistsError from dvc.repo.experiments.queue.base import BaseStashQueue from dvc.repo.experiments.utils import exp_refs_by_rev from dvc.scm import resolve_rev @@ -44,8 +45,6 @@ def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace): def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): - from dvc.repo.experiments.exceptions import ExperimentExistsError - dvc.experiments.run( exp_stage.addressing, name="foo", diff --git a/tests/func/experiments/test_save.py b/tests/func/experiments/test_save.py new file mode 100644 index 0000000000..7ad61aa2e4 --- /dev/null +++ b/tests/func/experiments/test_save.py @@ -0,0 +1,85 @@ +from contextlib import nullcontext + +import pytest +from funcy import first + +from dvc.repo.experiments.exceptions import ExperimentExistsError +from dvc.repo.experiments.utils import exp_refs_by_rev +from dvc.scm import resolve_rev +from dvc.stage.exceptions import StageCommitError + + +@pytest.mark.parametrize("name", (None, "test")) +def test_exp_save(tmp_dir, dvc, scm, exp_stage, name): + baseline = scm.get_rev() + + exp = dvc.experiments.save(name=name) + ref_info = first(exp_refs_by_rev(scm, exp)) + assert ref_info and ref_info.baseline_sha == baseline + + exp_name = name if name else ref_info.name + assert dvc.experiments.get_exact_name([exp])[exp] == exp_name + assert resolve_rev(scm, exp_name) == exp + + +@pytest.mark.parametrize( + ("force", "expected_raises"), + ( + (False, pytest.raises(StageCommitError)), + (True, nullcontext()), + ), +) +def test_exp_save_force(tmp_dir, dvc, scm, exp_stage, force, expected_raises): + with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh: + fh.write("\n# dummy change") + + with expected_raises: + dvc.experiments.save(force=force) + + +def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage): + dvc.experiments.save(name="dummy") + + with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh: + fh.write("\n# dummy change") + + with pytest.raises(ExperimentExistsError): + dvc.experiments.save(name="dummy") + + dvc.experiments.save(name="dummy", force=True) + + +def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage): + baseline = scm.get_rev() + dvc.experiments.save(name="exp-1") + + tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline") + dvc.experiments.save(name="exp-2") + + all_exps = dvc.experiments.ls(all_commits=True) + assert all_exps[baseline[:7]] == ["exp-1"] + assert all_exps["master"] == ["exp-2"] + + +def test_exp_save_with_staged_changes(tmp_dir, dvc, scm): + tmp_dir.gen({"new_file": "new_file"}) + scm.add("new_file") + + dvc.experiments.save(name="exp") + + _, _, unstaged = scm.status() + assert "new_file" in unstaged + + +def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage): + new_file = tmp_dir / "new_file" + for i in range(2): + new_file.write_text(f"exp-{i}") + dvc.experiments.save(name=f"exp-{i}", include_untracked=["new_file"]) + + _, _, unstaged = scm.status() + assert "new_file" in unstaged + assert new_file.read_text() == f"exp-{i}" + + dvc.experiments.apply("exp-0") + assert new_file.read_text() == "exp-0" diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index c20130db8b..36ff5c352e 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -16,6 +16,7 @@ from dvc.commands.experiments.push import CmdExperimentsPush from dvc.commands.experiments.remove import CmdExperimentsRemove from dvc.commands.experiments.run import CmdExperimentsRun +from dvc.commands.experiments.save import CmdExperimentsSave from dvc.commands.experiments.show import CmdExperimentsShow, show_experiments from dvc.exceptions import InvalidArgumentError from dvc.repo import Repo @@ -934,3 +935,17 @@ def test_show_experiments_pcp(tmp_dir, mocker): assert kwargs["output_path"] == str(tmp_dir / "dvc_plots" / "index.html") assert kwargs["color_by"] == "Experiment" + + +def test_experiments_save(dvc, scm, mocker): + cli_args = parse_args(["exp", "save", "--name", "exp-name", "--force"]) + assert cli_args.func == CmdExperimentsSave + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.experiments.save.save", return_value="acabb") + + assert cmd.run() == 0 + + m.assert_called_once_with( + cmd.repo, name="exp-name", force=True, include_untracked=[] + )