Skip to content

Commit

Permalink
exp save: Reintroduce implementation based on executor/queue.
Browse files Browse the repository at this point in the history
Closes #9058 .

Ensures behavior of `_stash_exp` in `exp run` is matched in `exp save` without duplicating logic.

The simpler implementation turned out to be harder to maintain when trying to keep the behavior of `exp run` and `exp save` in sync.
  • Loading branch information
daavoo committed Mar 14, 2023
1 parent 16ecd53 commit dabe361
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 92 deletions.
54 changes: 54 additions & 0 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,60 @@ def _from_stash_entry(
**kwargs,
)

@classmethod
def save(
    cls,
    info: "ExecutorInfo",
    force: bool = False,
    include_untracked: Optional[List[str]] = None,
) -> ExecutorResult:
    """Commit the current workspace state as an experiment.

    Opens the repo located at ``info.root_dir``/``info.dvc_dir``, switches
    into the executor working directory, runs ``dvc.commit`` for all
    stages, optionally stages extra untracked paths, and then delegates to
    ``cls.commit`` to create the experiment commit. The outcome is recorded
    on ``info`` (``result_hash``, ``result_ref``, ``result_force`` and
    ``status``).

    Args:
        info: Executor info. ``root_dir``, ``dvc_dir``, ``wdir`` and
            ``name`` are read; the result fields and ``status`` are
            written.
        force: Forwarded to ``cls.commit``.
        include_untracked: Paths passed to ``dvc.scm.add`` before the
            experiment commit is made. Skipped when empty or ``None``.

    Returns:
        ``ExecutorResult(ref, exp_ref, info.result_force)``, where ``ref``
        is the value of ``EXEC_BRANCH`` read without following symbolic
        refs, and ``exp_ref`` is the ``ExpRefInfo`` parsed from it (or
        ``None`` when the ref is unset).

    Raises:
        DvcException: Re-raised after ``info.status`` has been set to
            ``TaskStatus.FAILED``.
    """
    from dvc.repo import Repo

    exp_hash: Optional[str] = None
    exp_ref: Optional[ExpRefInfo] = None
    # Bound up front, like the two names above, so the final ``return``
    # never depends on how far the ``try`` body got.
    ref: Optional[str] = None

    dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
    old_cwd = os.getcwd()

    try:
        # The directory change lives inside the ``try`` so that a failure
        # here (e.g. a missing ``wdir``) still closes the repo in
        # ``finally`` instead of leaking it.
        if info.wdir:
            os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
        else:
            os.chdir(dvc.root_dir)

        stages = dvc.commit([], force=True, relink=False)
        exp_hash = cls.hash_exp(stages)
        if include_untracked:
            dvc.scm.add(include_untracked)
        cls.commit(
            dvc.scm,  # type: ignore[arg-type]
            exp_hash,
            exp_name=info.name,
            force=force,
        )
        ref = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
        exp_ref = ExpRefInfo.from_ref(ref) if ref else None
        # Queried after the commit, so paths staged through
        # ``include_untracked`` are no longer reported here.
        untracked = dvc.scm.untracked_files()
        if untracked:
            logger.warning(
                "The following untracked files were present in "
                "the workspace before saving but "
                "will not be included in the experiment commit:\n"
                "\t%s",
                ", ".join(untracked),
            )
        info.result_hash = exp_hash
        info.result_ref = ref
        info.result_force = False
        info.status = TaskStatus.SUCCESS
    except DvcException:
        info.status = TaskStatus.FAILED
        raise
    finally:
        dvc.close()
        os.chdir(old_cwd)

    return ExecutorResult(ref, exp_ref, info.result_force)

@staticmethod
def hash_exp(stages: Iterable["PipelineStage"]) -> str:
from dvc.stage import PipelineStage
Expand Down
72 changes: 9 additions & 63 deletions dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
import os
from typing import TYPE_CHECKING, List, Optional

from pathspec import PathSpec

from dvc.scm import Git

from .exceptions import ExperimentExistsError
from .refs import ExpRefInfo
from .utils import check_ref_format, get_random_exp_name
from funcy import first

if TYPE_CHECKING:
from dvc.repo import Repo
Expand All @@ -17,34 +11,6 @@
logger = logging.getLogger(__name__)


def _save_experiment(
    repo: "Repo",
    baseline_rev: str,
    force: bool,
    name: Optional[str],
    include_untracked: Optional[List[str]],
) -> str:
    """Create an experiment commit from the workspace and return its rev.

    Runs ``repo.commit`` for all stages, resolves the experiment ref for
    ``baseline_rev`` and ``name`` (a random name is generated when ``name``
    is falsy), stages tracked changes plus any ``include_untracked`` paths,
    commits them, and points the experiment ref at the new revision.

    Raises:
        ExperimentExistsError: If the experiment ref already exists and
            ``force`` is false.
    """
    repo.commit([], force=True, relink=False)

    if not name:
        name = get_random_exp_name(repo.scm, baseline_rev)
    exp_ref_info = ExpRefInfo(baseline_rev, name)
    check_ref_format(repo.scm.dulwich, exp_ref_info)
    ref_name = str(exp_ref_info)

    existing_rev = repo.scm.get_ref(ref_name)
    if existing_rev and not force:
        raise ExperimentExistsError(exp_ref_info.name, command="save")

    assert isinstance(repo.scm, Git)

    # Stage modifications to already-tracked files first, then any
    # explicitly requested untracked paths.
    repo.scm.add([], update=True)
    if include_untracked:
        repo.scm.add(include_untracked)

    repo.scm.commit(f"dvc: commit experiment {name}", no_verify=True)
    new_rev = repo.scm.get_rev()
    repo.scm.set_ref(ref_name, new_rev, old_ref=None)
    return new_rev


def save(
repo: "Repo",
name: Optional[str] = None,
Expand All @@ -57,33 +23,13 @@ def save(
"""
logger.debug("Saving workspace in %s", os.getcwd())

assert isinstance(repo.scm, Git)

_, _, untracked = repo.scm.status()
if include_untracked:
spec = PathSpec.from_lines("gitwildmatch", include_untracked)
untracked = [file for file in untracked if not spec.match_file(file)]
if untracked:
logger.warning(
(
"The following untracked files were present in "
"the workspace before saving but "
"will not be included in the experiment commit:\n"
"\t%s"
),
", ".join(untracked),
)

with repo.scm.detach_head(client="dvc") as orig_head:
with repo.scm.stash_workspace() as workspace:
try:
if workspace is not None:
repo.scm.stash.apply(workspace)
queue = repo.experiments.workspace_queue
entry = repo.experiments.new(queue=queue, name=name, force=force)
executor = queue.init_executor(repo.experiments, entry)

exp_rev = _save_experiment(
repo, orig_head, force, name, include_untracked
)
finally:
repo.scm.reset(hard=True)
save_result = executor.save(
executor.info, force=force, include_untracked=include_untracked
)
result = queue.collect_executor(repo.experiments, executor, save_result)

return exp_rev
return first(result)
69 changes: 40 additions & 29 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

import pytest
from funcy import first

Expand All @@ -8,18 +6,21 @@
from dvc.scm import resolve_rev


@pytest.fixture
def modified_exp_stage(exp_stage, tmp_dir):
    """Append a dummy comment line to ``copy.py`` under ``tmp_dir``."""
    # NOTE(review): ``exp_stage`` is unused in the body; presumably it is
    # requested only so that fixture runs first — confirm.
    with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh:
        fh.write("\n# dummy change")
def setup_stage(tmp_dir, dvc, scm):
    """Create a one-stage pipeline and commit its files with message "init".

    Writes ``params.yaml`` containing ``foo: 1``, runs a stage named
    ``echo-foo`` with output ``bar``, then adds and commits ``dvc.yaml``,
    ``dvc.lock``, ``.gitignore`` and ``params.yaml``.
    """
    tmp_dir.gen("params.yaml", "foo: 1")
    dvc.run(cmd="echo ${foo} > bar", outs=["bar"], name="echo-foo")

    committed_paths = ["dvc.yaml", "dvc.lock", ".gitignore", "params.yaml"]
    scm.add(committed_paths)
    scm.commit("init")


def test_exp_save_unchanged(tmp_dir, dvc, scm, exp_stage):
def test_exp_save_unchanged(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)
dvc.experiments.save()


@pytest.mark.parametrize("name", (None, "test"))
def test_exp_save(tmp_dir, dvc, scm, exp_stage, name, modified_exp_stage):
def test_exp_save(tmp_dir, dvc, scm, name):
setup_stage(tmp_dir, dvc, scm)
baseline = scm.get_rev()

exp = dvc.experiments.save(name=name)
Expand All @@ -32,12 +33,12 @@ def test_exp_save(tmp_dir, dvc, scm, exp_stage, name, modified_exp_stage):
assert resolve_rev(scm, exp_name) == exp


def test_exp_save_overwrite_experiment(
tmp_dir, dvc, scm, exp_stage, modified_exp_stage
):
def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)
name = "dummy"
dvc.experiments.save(name=name)

tmp_dir.gen("params.yaml", "foo: 2")
with pytest.raises(ExperimentExistsError):
dvc.experiments.save(name=name)

Expand All @@ -54,12 +55,14 @@ def test_exp_save_overwrite_experiment(
"invalidname.",
),
)
def test_exp_save_invalid_name(tmp_dir, dvc, scm, exp_stage, name):
def test_exp_save_invalid_name(tmp_dir, dvc, scm, name):
setup_stage(tmp_dir, dvc, scm)
with pytest.raises(InvalidArgumentError):
dvc.experiments.save(name=name, force=True)


def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):
def test_exp_save_after_commit(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)
baseline = scm.get_rev()
dvc.experiments.save(name="exp-1", force=True)

Expand All @@ -72,36 +75,44 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)
tmp_dir.gen({"deleted": "deleted", "modified": "modified"})
scm.add_commit(["deleted", "modified"], "init")

(tmp_dir / "deleted").unlink()
tmp_dir.gen({"new_file": "new_file"})
scm.add("new_file")
(tmp_dir / "modified").write_text("foo")
scm.add(["deleted", "new_file", "modified"])

dvc.experiments.save(name="exp")
exp_rev = dvc.experiments.save(name="exp")
scm.checkout(exp_rev, force=True)
assert not (tmp_dir / "deleted").exists()
assert (tmp_dir / "new_file").exists()
assert (tmp_dir / "modified").read_text() == "foo"

_, _, unstaged = scm.status()
assert "new_file" in unstaged

def test_exp_save_include_untracked(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)

def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage):
new_file = tmp_dir / "new_file"
for i in range(2):
new_file.write_text(f"exp-{i}")
dvc.experiments.save(name=f"exp-{i}", include_untracked=["new_file"])
new_file.write_text("new_file")
dvc.experiments.save(name="exp", include_untracked=["new_file"])

_, _, unstaged = scm.status()
assert "new_file" in unstaged
assert new_file.read_text() == f"exp-{i}"
assert new_file.read_text() == "new_file"

dvc.experiments.apply("exp-0")
assert new_file.read_text() == "exp-0"


def test_exp_save_include_untracked_warning(tmp_dir, dvc, scm, caplog, exp_stage):
def test_exp_save_include_untracked_warning(tmp_dir, dvc, scm, mocker):
"""Regression test for https://github.com/iterative/dvc/issues/9061"""
setup_stage(tmp_dir, dvc, scm)

new_dir = tmp_dir / "new_dir"
new_dir.mkdir()
(new_dir / "foo").write_text("foo")
(new_dir / "bar").write_text("bar")

with caplog.at_level(logging.WARNING, logger="dvc.repo.experiments.save"):
dvc.experiments.save(name="exp", include_untracked=["new_dir"])
assert not caplog.records
logger = mocker.patch("dvc.repo.experiments.executor.base.logger")

dvc.experiments.save(name="exp", include_untracked=["new_dir"])
assert not logger.warning.called

0 comments on commit dabe361

Please sign in to comment.