Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exp save: initial implementation #8599

Merged
merged 4 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
queue_worker,
remove,
run,
save,
show,
)

Expand All @@ -34,6 +35,7 @@
queue_worker,
remove,
run,
save,
show,
]

Expand Down
79 changes: 79 additions & 0 deletions dvc/commands/experiments/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import argparse
import logging

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import DvcException
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsSave(CmdBase):
    """Command handler behind ``dvc exp save``."""

    def run(self):
        """Save the workspace as an experiment and report the outcome.

        Returns:
            0 when the experiment was saved, 1 when saving raised a
            ``DvcException`` (the failure is logged with its traceback).
        """
        save_options = {
            "name": self.args.name,
            "force": self.args.force,
            "include_untracked": self.args.include_untracked,
        }

        try:
            ref = self.repo.experiments.save(**save_options)
        except DvcException:
            logger.exception("failed to save experiment")
            return 1

        # Machine-readable output: emit only the ref and stop.
        if self.args.json:
            ui.write_json({"ref": ref})
            return 0

        # Human-readable output: resolve the ref to its experiment name and
        # follow up with a hint on promoting it to a branch.
        exact_names = self.repo.experiments.get_exact_name([ref])
        name = exact_names[ref]
        ui.write(f"Experiment has been saved as: {name}")
        ui.write(
            "\nTo promote an experiment to a Git branch run:\n\n"
            "\tdvc exp branch <exp> <branch>\n"
        )
        return 0


def add_parser(experiments_subparsers, parent_parser):
EXPERIMENTS_SAVE_HELP = "Save current workspace as a dvc experiment."
save_parser = experiments_subparsers.add_parser(
"save",
parents=[parent_parser],
description=append_doc_link(EXPERIMENTS_SAVE_HELP, "exp/save"),
help=EXPERIMENTS_SAVE_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
save_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Save even if hash value for dependencies/outputs changed.",
)
save_parser.add_argument(
"--json",
"--show-json",
action="store_true",
default=False,
help="Show output in JSON format.",
)
save_parser.add_argument(
"-n",
"--name",
default=None,
help=(
"Human-readable experiment name. If not specified, a name will "
"be auto-generated."
),
metavar="<name>",
)
save_parser.add_argument(
"-I",
"--include-untracked",
action="append",
default=[],
help="List of untracked paths to include in the experiment.",
metavar="<path>",
)
Comment on lines +71 to +78
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dberenbaum

$ dvc stage add -q -n exp_save -M metrics.yaml 'echo "foo: 1" > metrics.yaml'
$ echo misc > misc    
$ dvc repro -q
$ dvc exp save -I "misc" -I "metrics.yaml" -I "dvc.lock" -I "dvc.yaml" 
$ git stash -u
$ dvc exp apply exp-48145  
$ git status
On branch master
Untracked files:
  (use "git add <file>..." to include in what will be committed)
	dvc.lock
	dvc.yaml
	metrics.yaml
	misc
$ dvc exp show
dvc exp show
 ────────────────────────────────────────── 
  Experiment                Created    foo  
 ────────────────────────────────────────── 
  workspace                 -            1  
  master                    11:57 AM     -  
  └── 71d7388 [exp-48145]   11:58 AM     1  
 ────────────────────────────────────────── 

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In DVCLive I am hardcoding the option to include_untracked=self.dir

save_parser.set_defaults(func=CmdExperimentsSave)
5 changes: 5 additions & 0 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,11 @@ def run(self, *args, **kwargs):

return run(self.repo, *args, **kwargs)

def save(self, *args, **kwargs):
    """Save the current workspace as an experiment.

    Thin wrapper: forwards all arguments, together with this object's
    ``repo``, to ``dvc.repo.experiments.save.save`` and returns its result.
    """
    # Imported at call time, like the sibling wrapper methods do.
    from dvc.repo.experiments.save import save as _save_workspace

    return _save_workspace(self.repo, *args, **kwargs)

def gc(self, *args, **kwargs):
from dvc.repo.experiments.gc import gc

Expand Down
62 changes: 59 additions & 3 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import os
from contextlib import ExitStack
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from funcy import cached_property, retry
from scmrepo.exceptions import SCMError as _SCMError
from shortuuid import uuid

from dvc.exceptions import DvcException
from dvc.lock import LockError
from dvc.scm import SCM, GitMergeError
from dvc.utils.fs import makedirs, remove
Expand All @@ -21,17 +22,18 @@
EXEC_MERGE,
EXEC_NAMESPACE,
EXPS_TEMP,
ExpRefInfo,
)
from ..utils import EXEC_TMP_DIR, get_exp_rwlock
from .base import BaseExecutor, TaskStatus
from .base import BaseExecutor, ExecutorResult, TaskStatus

if TYPE_CHECKING:
from scmrepo.git import Git

from dvc.repo import Repo

from ..refs import ExpRefInfo
from ..stash import ExpStashEntry
from .base import ExecutorInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,3 +249,57 @@ def cleanup(self, infofile: str):
checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)
if checkpoint and checkpoint != self._orig_checkpoint:
self.scm.set_ref(EXEC_APPLY, checkpoint)

@classmethod
def save(
cls,
info: "ExecutorInfo",
force: bool = False,
include_untracked: Optional[List[str]] = None,
) -> ExecutorResult:
from dvc.repo import Repo

exp_hash: Optional[str] = None
exp_ref: Optional[ExpRefInfo] = None

dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
old_cwd = os.getcwd()
if info.wdir:
os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
else:
os.chdir(dvc.root_dir)

try:
stages = dvc.commit([], force=force)
exp_hash = cls.hash_exp(stages)
if include_untracked:
dvc.scm.add(include_untracked)
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=force,
)
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
untracked = dvc.scm.untracked_files()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: this check might slow down exp save considerably when large untracked directories are present, but I guess it's fine for now

if untracked:
logger.warning(
"The following untracked files were present in "
"the workspace before saving but "
"will not be included in the experiment commit:\n"
"\t%s",
", ".join(untracked),
)
info.result_hash = exp_hash
info.result_ref = ref
info.result_force = False
info.status = TaskStatus.SUCCESS
except DvcException:
info.status = TaskStatus.FAILED
raise
finally:
dvc.close()
os.chdir(old_cwd)

return ExecutorResult(ref, exp_ref, info.result_force)
44 changes: 44 additions & 0 deletions dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
import os
from typing import TYPE_CHECKING, List, Optional

from funcy import first

if TYPE_CHECKING:
from dvc.repo import Repo


logger = logging.getLogger(__name__)


def save(
repo: "Repo",
name: Optional[str] = None,
force: bool = False,
include_untracked: Optional[List[str]] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2. Neither dvc exp run nor dvc exp save saves arbitrary new files, even if they are added to git. This is related to the idea from the previous PR that it should include all untracked files. It might not be a blocker for now if solving the first issue is enough for dvclive, but it seems like there needs to be some mechanism to save new files to an experiment (either all untracked files or at least those that have been added to git).

@dberenbaum To cover the DVCLive scenario I included this option (which is not exposed in the CLI but I can expose it).

This allows passing pattern(s) of (potentially) untracked files, on which git add is then run internally.

) -> Optional[str]:
"""Save the current workspace status as an experiment.

Returns the saved experiment's SHAs.
"""
queue = repo.experiments.workspace_queue
logger.debug("Saving workspace in %s", os.getcwd())

staged, _, _ = repo.scm.status(untracked_files="no")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're checking for untracked files above, how about:

Suggested change
staged, _, _ = repo.scm.status(untracked_files="no")
staged, unstaged,untracked = repo.scm.status(untracked_files="no")
if untracked:
logger.warning(
"The following untracked files were present in "
"the workspace before saving but "
"will not be included in the experiment commit:\n"
"\t%s",
", ".join(untracked),
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this change and remove the other call, but we need to handle include_untracked

if staged:
logger.warning(
"Your workspace contains staged Git changes which will be "
"unstaged before saving this experiment."
)
repo.scm.reset()

entry = repo.experiments.new(queue=queue, name=name, force=force)
executor = queue.init_executor(repo.experiments, entry)

save_result = executor.save(
executor.info, force=force, include_untracked=include_untracked
)
result = queue.collect_executor(repo.experiments, executor, save_result)

exp_rev = first(result)
return exp_rev
3 changes: 1 addition & 2 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from dvc.dvcfile import PIPELINE_FILE
from dvc.exceptions import ReproductionError
from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.queue.base import BaseStashQueue
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
Expand Down Expand Up @@ -44,8 +45,6 @@ def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace):


def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace):
from dvc.repo.experiments.exceptions import ExperimentExistsError

dvc.experiments.run(
exp_stage.addressing,
name="foo",
Expand Down
85 changes: 85 additions & 0 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from contextlib import nullcontext

import pytest
from funcy import first

from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
from dvc.stage.exceptions import StageCommitError


@pytest.mark.parametrize("name", (None, "test"))
def test_exp_save(tmp_dir, dvc, scm, exp_stage, name):
    """A saved experiment is based on the current rev and resolvable by name.

    Runs once with an auto-generated name (``None``) and once with an
    explicit one.
    """
    baseline_rev = scm.get_rev()

    saved_rev = dvc.experiments.save(name=name)

    # The experiment ref must exist and hang off the pre-save baseline.
    ref_info = first(exp_refs_by_rev(scm, saved_rev))
    assert ref_info
    assert ref_info.baseline_sha == baseline_rev

    # Without an explicit name, fall back to whatever name the ref carries.
    expected_name = name or ref_info.name
    exact_names = dvc.experiments.get_exact_name([saved_rev])
    assert exact_names[saved_rev] == expected_name
    assert resolve_rev(scm, expected_name) == saved_rev


@pytest.mark.parametrize(
    ("force", "expected_raises"),
    (
        (False, pytest.raises(StageCommitError)),
        (True, nullcontext()),
    ),
)
def test_exp_save_force(tmp_dir, dvc, scm, exp_stage, force, expected_raises):
    """Saving after editing ``copy.py`` fails unless ``force`` is set.

    NOTE(review): presumably ``copy.py`` is a dependency of the stage created
    by the ``exp_stage`` fixture -- confirm against the fixture definition.
    """
    script_path = tmp_dir / "copy.py"
    with open(script_path, "a", encoding="utf-8") as script:
        script.write("\n# dummy change")

    # ``expected_raises`` is either pytest.raises(...) or a no-op context.
    with expected_raises:
        dvc.experiments.save(force=force)


def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage):
    """Reusing an experiment name raises unless ``force=True`` is passed."""
    exp_name = "dummy"
    dvc.experiments.save(name=exp_name)

    # Change the workspace so the second save differs from the first.
    script_path = tmp_dir / "copy.py"
    with open(script_path, "a", encoding="utf-8") as script:
        script.write("\n# dummy change")

    with pytest.raises(ExperimentExistsError):
        dvc.experiments.save(name=exp_name)

    # With force the same name may be saved again.
    dvc.experiments.save(name=exp_name, force=True)


def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):
    """Experiments saved before and after a commit are listed separately."""
    first_baseline = scm.get_rev()
    dvc.experiments.save(name="exp-1")

    # Move HEAD forward, then save a second experiment on the new baseline.
    tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline")
    dvc.experiments.save(name="exp-2")

    listed = dvc.experiments.ls(all_commits=True)
    old_baseline_key = first_baseline[:7]  # listing is keyed by short SHA
    assert listed[old_baseline_key] == ["exp-1"]
    assert listed["master"] == ["exp-2"]


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
    """A file staged before saving is no longer staged afterwards."""
    staged_file = "new_file"
    tmp_dir.gen({staged_file: "new_file"})
    scm.add(staged_file)

    dvc.experiments.save(name="exp")

    # Only the third element of the status tuple matters here.
    status = scm.status()
    untracked = status[2]
    assert staged_file in untracked


def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage):
    """``include_untracked`` captures a file without tracking it in Git.

    Two experiments are saved with different contents of the same untracked
    file; applying the first one must restore its contents.
    """
    new_file = tmp_dir / "new_file"
    for i in range(2):
        exp_name = f"exp-{i}"
        new_file.write_text(exp_name)
        dvc.experiments.save(name=exp_name, include_untracked=["new_file"])

    # After saving, the file is still reported in the third status bucket
    # and still holds what the last iteration wrote.
    status = scm.status()
    untracked = status[2]
    assert "new_file" in untracked
    assert new_file.read_text() == exp_name

    dvc.experiments.apply("exp-0")
    assert new_file.read_text() == "exp-0"
15 changes: 15 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dvc.commands.experiments.push import CmdExperimentsPush
from dvc.commands.experiments.remove import CmdExperimentsRemove
from dvc.commands.experiments.run import CmdExperimentsRun
from dvc.commands.experiments.save import CmdExperimentsSave
from dvc.commands.experiments.show import CmdExperimentsShow, show_experiments
from dvc.exceptions import InvalidArgumentError
from dvc.repo import Repo
Expand Down Expand Up @@ -934,3 +935,17 @@ def test_show_experiments_pcp(tmp_dir, mocker):

assert kwargs["output_path"] == str(tmp_dir / "dvc_plots" / "index.html")
assert kwargs["color_by"] == "Experiment"


def test_experiments_save(dvc, scm, mocker):
    """``dvc exp save`` CLI flags are forwarded to the repo-level ``save``."""
    argv = ["exp", "save", "--name", "exp-name", "--force"]
    cli_args = parse_args(argv)
    assert cli_args.func == CmdExperimentsSave

    cmd = cli_args.func(cli_args)
    # Patch the implementation so no real experiment is created.
    save_mock = mocker.patch(
        "dvc.repo.experiments.save.save", return_value="acabb"
    )

    exit_code = cmd.run()
    assert exit_code == 0

    # ``include_untracked`` was not given on the command line, so the
    # parser default (an empty list) is what gets forwarded.
    save_mock.assert_called_once_with(
        cmd.repo, name="exp-name", force=True, include_untracked=[]
    )