diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 32b82de42a2..ee19d034d4b 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -256,6 +256,60 @@ def _from_stash_entry( **kwargs, ) + @classmethod + def save( + cls, + info: "ExecutorInfo", + force: bool = False, + include_untracked: Optional[List[str]] = None, + ) -> ExecutorResult: + from dvc.repo import Repo + + exp_hash: Optional[str] = None + exp_ref: Optional[ExpRefInfo] = None + + dvc = Repo(os.path.join(info.root_dir, info.dvc_dir)) + old_cwd = os.getcwd() + if info.wdir: + os.chdir(os.path.join(dvc.scm.root_dir, info.wdir)) + else: + os.chdir(dvc.root_dir) + + try: + stages = dvc.commit([], force=True, relink=False) + exp_hash = cls.hash_exp(stages) + if include_untracked: + dvc.scm.add(include_untracked) + cls.commit( + dvc.scm, # type: ignore[arg-type] + exp_hash, + exp_name=info.name, + force=force, + ) + ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False) + exp_ref = ExpRefInfo.from_ref(ref) if ref else None + untracked = dvc.scm.untracked_files() + if untracked: + logger.warning( + "The following untracked files were present in " + "the workspace before saving but " + "will not be included in the experiment commit:\n" + "\t%s", + ", ".join(untracked), + ) + info.result_hash = exp_hash + info.result_ref = ref + info.result_force = False + info.status = TaskStatus.SUCCESS + except DvcException: + info.status = TaskStatus.FAILED + raise + finally: + dvc.close() + os.chdir(old_cwd) + + return ExecutorResult(ref, exp_ref, info.result_force) + @staticmethod def hash_exp(stages: Iterable["PipelineStage"]) -> str: from dvc.stage import PipelineStage diff --git a/dvc/repo/experiments/save.py b/dvc/repo/experiments/save.py index 8851c134167..5853924e736 100644 --- a/dvc/repo/experiments/save.py +++ b/dvc/repo/experiments/save.py @@ -2,13 +2,7 @@ import os from typing import TYPE_CHECKING, List, Optional -from pathspec import PathSpec - -from dvc.scm import Git - -from .exceptions import ExperimentExistsError -from .refs import ExpRefInfo -from .utils import check_ref_format, get_random_exp_name +from funcy import first if TYPE_CHECKING: from dvc.repo import Repo @@ -17,34 +11,6 @@ logger = logging.getLogger(__name__) -def _save_experiment( - repo: "Repo", - baseline_rev: str, - force: bool, - name: Optional[str], - include_untracked: Optional[List[str]], -) -> str: - repo.commit([], force=True, relink=False) - - name = name or get_random_exp_name(repo.scm, baseline_rev) - ref_info = ExpRefInfo(baseline_rev, name) - check_ref_format(repo.scm.dulwich, ref_info) - ref = str(ref_info) - if repo.scm.get_ref(ref) and not force: - raise ExperimentExistsError(ref_info.name, command="save") - - assert isinstance(repo.scm, Git) - - repo.scm.add([], update=True) - if include_untracked: - repo.scm.add(include_untracked) - repo.scm.commit(f"dvc: commit experiment {name}", no_verify=True) - exp_rev = repo.scm.get_rev() - repo.scm.set_ref(ref, exp_rev, old_ref=None) - - return exp_rev - - def save( repo: "Repo", name: Optional[str] = None, @@ -57,33 +23,16 @@ def save( """ logger.debug("Saving workspace in %s", os.getcwd()) - assert isinstance(repo.scm, Git) + queue = repo.experiments.workspace_queue + entry = repo.experiments.new(queue=queue, name=name, force=force) + executor = queue.init_executor(repo.experiments, entry) - _, _, untracked = repo.scm.status() - if include_untracked: - spec = PathSpec.from_lines("gitwildmatch", include_untracked) - untracked = [file for file in untracked if not spec.match_file(file)] - if untracked: - logger.warning( - ( - "The following untracked files were present in " - "the workspace before saving but " - "will not be included in the experiment commit:\n" - "\t%s" - ), - ", ".join(untracked), + try: + save_result = executor.save( + executor.info, force=force, include_untracked=include_untracked ) + result = queue.collect_executor(repo.experiments, executor, save_result) + finally: + executor.cleanup(None) - with repo.scm.detach_head(client="dvc") as orig_head: - with repo.scm.stash_workspace() as workspace: - try: - if workspace is not None: - repo.scm.stash.apply(workspace) - - exp_rev = _save_experiment( - repo, orig_head, force, name, include_untracked - ) - finally: - repo.scm.reset(hard=True) - - return exp_rev + return first(result) diff --git a/tests/func/experiments/test_save.py b/tests/func/experiments/test_save.py index ac6690f630c..c756d5b371b 100644 --- a/tests/func/experiments/test_save.py +++ b/tests/func/experiments/test_save.py @@ -1,5 +1,3 @@ -import logging - import pytest from funcy import first @@ -8,18 +6,21 @@ from dvc.scm import resolve_rev -@pytest.fixture -def modified_exp_stage(exp_stage, tmp_dir): - with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh: - fh.write("\n# dummy change") +def setup_stage(tmp_dir, dvc, scm): + tmp_dir.gen("params.yaml", "foo: 1") + dvc.run(name="echo-foo", outs=["bar"], cmd="echo ${foo} > bar") + scm.add(["dvc.yaml", "dvc.lock", ".gitignore", "params.yaml"]) + scm.commit("init") -def test_exp_save_unchanged(tmp_dir, dvc, scm, exp_stage): +def test_exp_save_unchanged(tmp_dir, dvc, scm): + setup_stage(tmp_dir, dvc, scm) dvc.experiments.save() @pytest.mark.parametrize("name", (None, "test")) -def test_exp_save(tmp_dir, dvc, scm, exp_stage, name, modified_exp_stage): +def test_exp_save(tmp_dir, dvc, scm, name): + setup_stage(tmp_dir, dvc, scm) baseline = scm.get_rev() exp = dvc.experiments.save(name=name) @@ -32,12 +33,12 @@ def test_exp_save(tmp_dir, dvc, scm, exp_stage, name, modified_exp_stage): assert resolve_rev(scm, exp_name) == exp -def test_exp_save_overwrite_experiment( - tmp_dir, dvc, scm, exp_stage, modified_exp_stage -): +def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm): + setup_stage(tmp_dir, dvc, scm) name = "dummy" dvc.experiments.save(name=name) + tmp_dir.gen("params.yaml", "foo: 2") with pytest.raises(ExperimentExistsError): dvc.experiments.save(name=name) @@ -54,12 +55,14 @@ def test_exp_save_overwrite_experiment( "invalidname.", ), ) -def test_exp_save_invalid_name(tmp_dir, dvc, scm, exp_stage, name): +def test_exp_save_invalid_name(tmp_dir, dvc, scm, name): + setup_stage(tmp_dir, dvc, scm) with pytest.raises(InvalidArgumentError): dvc.experiments.save(name=name, force=True) -def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage): +def test_exp_save_after_commit(tmp_dir, dvc, scm): + setup_stage(tmp_dir, dvc, scm) baseline = scm.get_rev() dvc.experiments.save(name="exp-1", force=True) @@ -72,36 +75,44 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage): def test_exp_save_with_staged_changes(tmp_dir, dvc, scm): + setup_stage(tmp_dir, dvc, scm) + tmp_dir.gen({"deleted": "deleted", "modified": "modified"}) + scm.add_commit(["deleted", "modified"], "init") + + (tmp_dir / "deleted").unlink() tmp_dir.gen({"new_file": "new_file"}) - scm.add("new_file") + (tmp_dir / "modified").write_text("foo") + scm.add(["deleted", "new_file", "modified"]) - dvc.experiments.save(name="exp") + exp_rev = dvc.experiments.save(name="exp") + scm.checkout(exp_rev, force=True) + assert not (tmp_dir / "deleted").exists() + assert (tmp_dir / "new_file").exists() + assert (tmp_dir / "modified").read_text() == "foo" - _, _, unstaged = scm.status() - assert "new_file" in unstaged +def test_exp_save_include_untracked(tmp_dir, dvc, scm): + setup_stage(tmp_dir, dvc, scm) -def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage): new_file = tmp_dir / "new_file" - for i in range(2): - new_file.write_text(f"exp-{i}") - dvc.experiments.save(name=f"exp-{i}", include_untracked=["new_file"]) + new_file.write_text("new_file") + dvc.experiments.save(name="exp", include_untracked=["new_file"]) _, _, unstaged = scm.status() assert "new_file" in unstaged - assert new_file.read_text() == f"exp-{i}" + assert new_file.read_text() == "new_file" - dvc.experiments.apply("exp-0") - assert new_file.read_text() == "exp-0" - -def test_exp_save_include_untracked_warning(tmp_dir, dvc, scm, caplog, exp_stage): +def test_exp_save_include_untracked_warning(tmp_dir, dvc, scm, mocker): """Regression test for https://github.com/iterative/dvc/issues/9061""" + setup_stage(tmp_dir, dvc, scm) + new_dir = tmp_dir / "new_dir" new_dir.mkdir() (new_dir / "foo").write_text("foo") (new_dir / "bar").write_text("bar") - with caplog.at_level(logging.WARNING, logger="dvc.repo.experiments.save"): - dvc.experiments.save(name="exp", include_untracked=["new_dir"]) - assert not caplog.records + logger = mocker.patch("dvc.repo.experiments.executor.base.logger") + + dvc.experiments.save(name="exp", include_untracked=["new_dir"]) + assert not logger.warning.called