-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
exp save: initial implementation #8599
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import argparse | ||
import logging | ||
|
||
from dvc.cli.command import CmdBase | ||
from dvc.cli.utils import append_doc_link | ||
from dvc.exceptions import DvcException | ||
from dvc.ui import ui | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CmdExperimentsSave(CmdBase):
    """CLI command that saves the current workspace as a DVC experiment."""

    def run(self):
        """Save the experiment and report where it was stored.

        Returns:
            0 on success, 1 when saving raised a ``DvcException``.
        """
        try:
            ref = self.repo.experiments.save(
                name=self.args.name,
                force=self.args.force,
                include_untracked=self.args.include_untracked,
            )
        except DvcException:
            # logger.exception keeps the traceback next to the message.
            logger.exception("failed to save experiment")
            return 1

        if self.args.json:
            # JSON mode prints only the saved ref and skips the hint below.
            ui.write_json({"ref": ref})
            return 0

        name = self.repo.experiments.get_exact_name([ref])[ref]
        ui.write(f"Experiment has been saved as: {name}")
        ui.write(
            "\nTo promote an experiment to a Git branch run:\n\n"
            "\tdvc exp branch <exp> <branch>\n"
        )
        return 0
|
||
|
||
def add_parser(experiments_subparsers, parent_parser):
    """Register the ``save`` subcommand and its options.

    Args:
        experiments_subparsers: argparse subparsers object that receives
            the new ``save`` parser.
        parent_parser: parser passed through ``parents`` so its options
            are inherited by ``save``.
    """
    help_text = "Save current workspace as a dvc experiment."
    parser = experiments_subparsers.add_parser(
        "save",
        parents=[parent_parser],
        description=append_doc_link(help_text, "exp/save"),
        help=help_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Boolean switches.
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        help="Save even if hash value for dependencies/outputs changed.",
    )
    parser.add_argument(
        "--json",
        "--show-json",
        action="store_true",
        default=False,
        help="Show output in JSON format.",
    )

    # Options that take a value.
    parser.add_argument(
        "-n",
        "--name",
        default=None,
        help=(
            "Human-readable experiment name. If not specified, a name will "
            "be auto-generated."
        ),
        metavar="<name>",
    )
    parser.add_argument(
        "-I",
        "--include-untracked",
        action="append",  # may be given several times, one path each
        default=[],
        help="List of untracked paths to include in the experiment.",
        metavar="<path>",
    )

    parser.set_defaults(func=CmdExperimentsSave)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,13 @@ | |
import os | ||
from contextlib import ExitStack | ||
from tempfile import mkdtemp | ||
from typing import TYPE_CHECKING, Optional | ||
from typing import TYPE_CHECKING, List, Optional | ||
|
||
from funcy import cached_property, retry | ||
from scmrepo.exceptions import SCMError as _SCMError | ||
from shortuuid import uuid | ||
|
||
from dvc.exceptions import DvcException | ||
from dvc.lock import LockError | ||
from dvc.scm import SCM, GitMergeError | ||
from dvc.utils.fs import makedirs, remove | ||
|
@@ -21,17 +22,18 @@ | |
EXEC_MERGE, | ||
EXEC_NAMESPACE, | ||
EXPS_TEMP, | ||
ExpRefInfo, | ||
) | ||
from ..utils import EXEC_TMP_DIR, get_exp_rwlock | ||
from .base import BaseExecutor, TaskStatus | ||
from .base import BaseExecutor, ExecutorResult, TaskStatus | ||
|
||
if TYPE_CHECKING: | ||
from scmrepo.git import Git | ||
|
||
from dvc.repo import Repo | ||
|
||
from ..refs import ExpRefInfo | ||
from ..stash import ExpStashEntry | ||
from .base import ExecutorInfo | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -247,3 +249,57 @@ def cleanup(self, infofile: str): | |
checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) | ||
if checkpoint and checkpoint != self._orig_checkpoint: | ||
self.scm.set_ref(EXEC_APPLY, checkpoint) | ||
|
||
@classmethod
def save(
    cls,
    info: "ExecutorInfo",
    force: bool = False,
    include_untracked: Optional[List[str]] = None,
) -> ExecutorResult:
    """Commit the current workspace as an experiment.

    Args:
        info: executor info. Its ``root_dir``, ``dvc_dir``, ``wdir`` and
            ``name`` are read; its ``result_hash``, ``result_ref``,
            ``result_force`` and ``status`` are updated with the outcome.
        force: forwarded to both ``dvc.commit`` and ``cls.commit``.
        include_untracked: paths handed to ``scm.add`` before the
            experiment commit is created.

    Returns:
        An ``ExecutorResult`` built from the experiment ref, the parsed
        ``ExpRefInfo`` (``None`` when no ref exists) and
        ``info.result_force``.

    Raises:
        DvcException: re-raised after ``info.status`` is set to
            ``TaskStatus.FAILED``.
    """
    from dvc.repo import Repo

    exp_hash: Optional[str] = None
    exp_ref: Optional[ExpRefInfo] = None

    dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
    old_cwd = os.getcwd()

    try:
        # The directory change happens inside the try block so that a
        # failing chdir still closes the repo in ``finally``.
        if info.wdir:
            os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
        else:
            os.chdir(dvc.root_dir)

        stages = dvc.commit([], force=force)
        exp_hash = cls.hash_exp(stages)
        if include_untracked:
            dvc.scm.add(include_untracked)
        cls.commit(
            dvc.scm,
            exp_hash,
            exp_name=info.name,
            force=force,
        )
        ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
        exp_ref = ExpRefInfo.from_ref(ref) if ref else None

        # NOTE(review): a reviewer remarked "this check might slow down"
        # (the rest of the remark is truncated in the source); confirm
        # whether the untracked-files scan is acceptable on large repos.
        untracked = dvc.scm.untracked_files()
        if untracked:
            logger.warning(
                "The following untracked files were present in "
                "the workspace before saving but "
                "will not be included in the experiment commit:\n"
                "\t%s",
                ", ".join(untracked),
            )
        info.result_hash = exp_hash
        info.result_ref = ref
        info.result_force = False
        info.status = TaskStatus.SUCCESS
    except DvcException:
        info.status = TaskStatus.FAILED
        raise
    finally:
        # Always release the repo and restore the caller's directory.
        dvc.close()
        os.chdir(old_cwd)

    # NOTE(review): the first value passed is the ref name, not
    # ``exp_hash`` -- confirm this matches ExecutorResult's field order.
    return ExecutorResult(ref, exp_ref, info.result_force)
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,44 @@ | ||||||||||||||||||||||
import logging | ||||||||||||||||||||||
import os | ||||||||||||||||||||||
from typing import TYPE_CHECKING, List, Optional | ||||||||||||||||||||||
|
||||||||||||||||||||||
from funcy import first | ||||||||||||||||||||||
|
||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||
from dvc.repo import Repo | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def save(
    repo: "Repo",
    name: Optional[str] = None,
    force: bool = False,
    include_untracked: Optional[List[str]] = None,
) -> Optional[str]:
    """Save the current workspace status as an experiment.

    Args:
        repo: repository whose workspace is saved.
        name: experiment name, forwarded to ``repo.experiments.new``.
        force: forwarded to ``repo.experiments.new`` and ``executor.save``.
        include_untracked: forwarded to ``executor.save``.

    Returns:
        The first item of the collected executor result (callers use it as
        the saved experiment's revision), or ``None`` if the result is empty.
    """
    # NOTE(review): the PR author said ``include_untracked`` was added to
    # cover the DVCLive scenario, allowing pattern(s) of (potentially)
    # untracked files to be passed (the remark is truncated in the source).
    queue = repo.experiments.workspace_queue
    logger.debug("Saving workspace in %s", os.getcwd())

    # NOTE(review): reviewers proposed reusing this status call for the
    # untracked-files check and removing the other call; the suggested code
    # and the remaining caveat are not visible in the source.
    staged_changes, _, _ = repo.scm.status(untracked_files="no")
    if staged_changes:
        logger.warning(
            "Your workspace contains staged Git changes which will be "
            "unstaged before saving this experiment."
        )
        repo.scm.reset()

    stash_entry = repo.experiments.new(queue=queue, name=name, force=force)
    executor = queue.init_executor(repo.experiments, stash_entry)

    exec_result = executor.save(
        executor.info, force=force, include_untracked=include_untracked
    )
    collected = queue.collect_executor(repo.experiments, executor, exec_result)

    return first(collected)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from contextlib import nullcontext | ||
|
||
import pytest | ||
from funcy import first | ||
|
||
from dvc.repo.experiments.exceptions import ExperimentExistsError | ||
from dvc.repo.experiments.utils import exp_refs_by_rev | ||
from dvc.scm import resolve_rev | ||
from dvc.stage.exceptions import StageCommitError | ||
|
||
|
||
@pytest.mark.parametrize("name", (None, "test"))
def test_exp_save(tmp_dir, dvc, scm, exp_stage, name):
    """A saved experiment is based on the current rev and resolvable by name."""
    baseline_rev = scm.get_rev()

    saved_rev = dvc.experiments.save(name=name)
    ref_info = first(exp_refs_by_rev(scm, saved_rev))
    assert ref_info
    assert ref_info.baseline_sha == baseline_rev

    # Without an explicit name, compare against the name stored in the ref.
    expected_name = name or ref_info.name
    exact_names = dvc.experiments.get_exact_name([saved_rev])
    assert exact_names[saved_rev] == expected_name
    assert resolve_rev(scm, expected_name) == saved_rev
|
||
|
||
@pytest.mark.parametrize(
    ("force", "expected_raises"),
    (
        (False, pytest.raises(StageCommitError)),
        (True, nullcontext()),
    ),
)
def test_exp_save_force(tmp_dir, dvc, scm, exp_stage, force, expected_raises):
    """Saving after editing ``copy.py`` fails unless ``force`` is set."""
    script_path = tmp_dir / "copy.py"
    with open(script_path, "a", encoding="utf-8") as script:
        script.write("\n# dummy change")

    with expected_raises:
        dvc.experiments.save(force=force)
|
||
|
||
def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage):
    """Reusing an experiment name raises unless ``force=True`` is passed."""
    exp_name = "dummy"
    dvc.experiments.save(name=exp_name)

    script_path = tmp_dir / "copy.py"
    with open(script_path, "a", encoding="utf-8") as script:
        script.write("\n# dummy change")

    with pytest.raises(ExperimentExistsError):
        dvc.experiments.save(name=exp_name)

    # With force the same name can be saved again.
    dvc.experiments.save(name=exp_name, force=True)
|
||
|
||
def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):
    """Experiments are listed under the commit they were saved on."""
    first_baseline = scm.get_rev()
    dvc.experiments.save(name="exp-1")

    # Create a new Git commit so the second experiment has a new baseline.
    tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline")
    dvc.experiments.save(name="exp-2")

    exps_by_rev = dvc.experiments.ls(all_commits=True)
    short_sha = first_baseline[:7]
    assert exps_by_rev[short_sha] == ["exp-1"]
    assert exps_by_rev["master"] == ["exp-2"]
|
||
|
||
def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
    """A file staged before saving is reported as unstaged afterwards."""
    file_name = "new_file"
    tmp_dir.gen({file_name: "new_file"})
    scm.add(file_name)

    dvc.experiments.save(name="exp")

    _, _, unstaged_files = scm.status()
    assert file_name in unstaged_files
|
||
|
||
def test_exp_save_include_untracked(tmp_dir, dvc, scm, exp_stage):
    """Untracked files passed via ``include_untracked`` are saved and applied."""
    new_file = tmp_dir / "new_file"
    # NOTE(review): indentation is lost in the source; the loop is assumed to
    # cover only the write and the save -- confirm against the real file.
    for i in range(2):
        new_file.write_text(f"exp-{i}")
        dvc.experiments.save(name=f"exp-{i}", include_untracked=["new_file"])

    _, _, unstaged_files = scm.status()
    assert "new_file" in unstaged_files
    # ``i`` still holds the last loop value here.
    assert new_file.read_text() == f"exp-{i}"

    dvc.experiments.apply("exp-0")
    assert new_file.read_text() == "exp-0"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dberenbaum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In DVCLive I am hardcoding the option to `include_untracked=self.dir`.