Skip to content

Commit

Permalink
exp save: Reintroduce implementation based on executor/queue.
Browse files Browse the repository at this point in the history
Closes #9058 .

Ensures that `exp save` matches the behavior of `_stash_exp` in `exp run` without duplicating logic.

The simpler implementation turned out to be harder to maintain when trying to keep the behavior of `exp run` and `exp save` in sync.
  • Loading branch information
daavoo committed Mar 13, 2023
1 parent 16ecd53 commit 8286bf0
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 72 deletions.
54 changes: 54 additions & 0 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,60 @@ def _from_stash_entry(
**kwargs,
)

@classmethod
def save(
    cls,
    info: "ExecutorInfo",
    force: bool = False,
    include_untracked: Optional[List[str]] = None,
) -> ExecutorResult:
    """Save the current workspace as an experiment commit.

    Opens the repo described by ``info``, switches into the executor's
    working directory, runs ``dvc.commit``, hashes the resulting stages,
    optionally adds ``include_untracked`` paths to Git, and then delegates
    to ``cls.commit`` to create the experiment commit.

    Args:
        info: Executor metadata. ``root_dir``, ``dvc_dir``, ``wdir`` and
            ``name`` are read; ``result_hash``, ``result_ref``,
            ``result_force`` and ``status`` are written.
        force: Forwarded to ``cls.commit``.
        include_untracked: Paths passed to ``dvc.scm.add`` before the
            experiment commit is created. Skipped when empty or ``None``.

    Returns:
        An ``ExecutorResult`` built from the ``EXEC_BRANCH`` ref, the
        ``ExpRefInfo`` parsed from it (or ``None``), and
        ``info.result_force``.

    Raises:
        DvcException: Re-raised after ``info.status`` has been set to
            ``TaskStatus.FAILED``.
    """
    from dvc.repo import Repo

    exp_hash: Optional[str] = None
    exp_ref: Optional[ExpRefInfo] = None
    ref: Optional[str] = None

    dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
    old_cwd = os.getcwd()

    try:
        # The chdir happens inside the ``try`` so that a failure here
        # (e.g. a missing ``wdir``) still closes the repo in ``finally``.
        if info.wdir:
            os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
        else:
            os.chdir(dvc.root_dir)

        stages = dvc.commit([], force=True, relink=False)
        exp_hash = cls.hash_exp(stages)
        if include_untracked:
            dvc.scm.add(include_untracked)
        cls.commit(
            dvc.scm,
            exp_hash,
            exp_name=info.name,
            force=force,
        )
        ref = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
        exp_ref = ExpRefInfo.from_ref(ref) if ref else None
        # Whatever is still untracked after the commit was not included.
        untracked = dvc.scm.untracked_files()
        if untracked:
            logger.warning(
                "The following untracked files were present in "
                "the workspace before saving but "
                "will not be included in the experiment commit:\n"
                "\t%s",
                ", ".join(untracked),
            )
        info.result_hash = exp_hash
        info.result_ref = ref
        info.result_force = False
        info.status = TaskStatus.SUCCESS
    except DvcException:
        info.status = TaskStatus.FAILED
        raise
    finally:
        dvc.close()
        os.chdir(old_cwd)

    # NOTE(review): the first argument is the EXEC_BRANCH ref, not
    # ``exp_hash`` -- confirm this matches ExecutorResult's field order.
    return ExecutorResult(ref, exp_ref, info.result_force)

@staticmethod
def hash_exp(stages: Iterable["PipelineStage"]) -> str:
from dvc.stage import PipelineStage
Expand Down
72 changes: 9 additions & 63 deletions dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
import os
from typing import TYPE_CHECKING, List, Optional

from pathspec import PathSpec

from dvc.scm import Git

from .exceptions import ExperimentExistsError
from .refs import ExpRefInfo
from .utils import check_ref_format, get_random_exp_name
from funcy import first

if TYPE_CHECKING:
from dvc.repo import Repo
Expand All @@ -17,34 +11,6 @@
logger = logging.getLogger(__name__)


def _save_experiment(
    repo: "Repo",
    baseline_rev: str,
    force: bool,
    name: Optional[str],
    include_untracked: Optional[List[str]],
) -> str:
    """Create an experiment commit from the workspace and point a ref at it.

    Args:
        repo: Repo whose ``scm`` is used for all Git operations.
        baseline_rev: Revision the experiment ref is based on.
        force: When false, an already existing experiment ref is an error.
        name: Experiment name; a random one is generated when falsy.
        include_untracked: Extra paths to add to Git before committing.

    Returns:
        The revision of the newly created experiment commit.

    Raises:
        ExperimentExistsError: The experiment ref already exists and
            ``force`` is false.
    """
    repo.commit([], force=True, relink=False)

    if not name:
        name = get_random_exp_name(repo.scm, baseline_rev)
    exp_ref_info = ExpRefInfo(baseline_rev, name)
    check_ref_format(repo.scm.dulwich, exp_ref_info)
    ref_name = str(exp_ref_info)

    # The lookup always runs; ``force`` only decides whether a hit is fatal.
    existing_ref = repo.scm.get_ref(ref_name)
    if existing_ref and not force:
        raise ExperimentExistsError(exp_ref_info.name, command="save")

    assert isinstance(repo.scm, Git)

    # presumably stages changes to already-tracked files -- verify against
    # the SCM wrapper's ``add(..., update=True)`` semantics
    repo.scm.add([], update=True)
    if include_untracked:
        repo.scm.add(include_untracked)
    repo.scm.commit(f"dvc: commit experiment {name}", no_verify=True)

    new_rev = repo.scm.get_rev()
    repo.scm.set_ref(ref_name, new_rev, old_ref=None)
    return new_rev


def save(
repo: "Repo",
name: Optional[str] = None,
Expand All @@ -57,33 +23,13 @@ def save(
"""
logger.debug("Saving workspace in %s", os.getcwd())

assert isinstance(repo.scm, Git)

_, _, untracked = repo.scm.status()
if include_untracked:
spec = PathSpec.from_lines("gitwildmatch", include_untracked)
untracked = [file for file in untracked if not spec.match_file(file)]
if untracked:
logger.warning(
(
"The following untracked files were present in "
"the workspace before saving but "
"will not be included in the experiment commit:\n"
"\t%s"
),
", ".join(untracked),
)

with repo.scm.detach_head(client="dvc") as orig_head:
with repo.scm.stash_workspace() as workspace:
try:
if workspace is not None:
repo.scm.stash.apply(workspace)
queue = repo.experiments.workspace_queue
entry = repo.experiments.new(queue=queue, name=name, force=force)
executor = queue.init_executor(repo.experiments, entry)

exp_rev = _save_experiment(
repo, orig_head, force, name, include_untracked
)
finally:
repo.scm.reset(hard=True)
save_result = executor.save(
executor.info, force=force, include_untracked=include_untracked
)
result = queue.collect_executor(repo.experiments, executor, save_result)

return exp_rev
return first(result)
29 changes: 20 additions & 9 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ def test_exp_save(tmp_dir, dvc, scm, exp_stage, name, modified_exp_stage):
assert resolve_rev(scm, exp_name) == exp


def test_exp_save_overwrite_experiment(
tmp_dir, dvc, scm, exp_stage, modified_exp_stage
):
def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm):
tmp_dir.gen("params.yaml", "foo: 1")
dvc.run(name="echo-foo", outs=["bar"], cmd="echo ${foo} > bar")
scm.add(["dvc.yaml", "dvc.lock", ".gitignore", "params.yaml"])
scm.commit("init")
name = "dummy"
dvc.experiments.save(name=name)

tmp_dir.gen("params.yaml", "foo: 2")
with pytest.raises(ExperimentExistsError):
dvc.experiments.save(name=name)

Expand Down Expand Up @@ -72,13 +75,21 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
tmp_dir.gen({"new_file": "new_file"})
scm.add("new_file")

dvc.experiments.save(name="exp")
tmp_dir.gen({"deleted": "deleted", "modified": "modified"})
scm.add_commit(["deleted", "modified"], "init")

_, _, unstaged = scm.status()
assert "new_file" in unstaged
(tmp_dir / "deleted").unlink()
tmp_dir.gen({"new_file": "new_file"})
(tmp_dir / "modified").write_text("foo")
scm.add(["deleted", "new_file", "modified"])

# prev_status = scm.status()
exp_rev = dvc.experiments.save(name="exp")
# assert scm.status() == prev_status
scm.checkout(exp_rev, force=True)
assert not (tmp_dir / "deleted").exists()
assert (tmp_dir / "new_file").exists()
assert (tmp_dir / "modified").read_text() == "foo"


def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage):
Expand Down

0 comments on commit 8286bf0

Please sign in to comment.