Skip to content

Commit

Permalink
experiments: add exp branch for promoting experiment to full branch
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Dec 2, 2020
1 parent 1692c0c commit 27831e7
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 4 deletions.
28 changes: 28 additions & 0 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,16 @@ def run(self):
return 0


class CmdExperimentsBranch(CmdBase):
def run(self):
if not self.repo.experiments:
return 0

self.repo.experiments.branch(self.args.experiment, self.args.branch)

return 0


def add_parser(subparsers, parent_parser):
EXPERIMENTS_HELP = "Commands to display and compare experiments."

Expand Down Expand Up @@ -833,6 +843,24 @@ def add_parser(subparsers, parent_parser):
)
experiments_gc_parser.set_defaults(func=CmdExperimentsGC)

EXPERIMENTS_BRANCH_HELP = "Promote an experiment to a Git branch."
experiments_branch_parser = experiments_subparsers.add_parser(
"branch",
parents=[parent_parser],
description=append_doc_link(
EXPERIMENTS_BRANCH_HELP, "experiments/branch"
),
help=EXPERIMENTS_BRANCH_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
experiments_branch_parser.add_argument(
"experiment", help="Experiment to be promoted.",
)
experiments_branch_parser.add_argument(
"branch", help="Git branch name to use.",
)
experiments_branch_parser.set_defaults(func=CmdExperimentsBranch)


def _add_run_common(parser):
"""Add common args for 'exp run' and 'exp resume'."""
Expand Down
21 changes: 17 additions & 4 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,14 @@ def _log_reproduced(self, revs: Iterable[str]):
for rev in revs:
name = self.get_exact_name(rev)
names.append(name if name else rev[:7])
msg = (
"\nReproduced experiment(s): {}\n"
fmt = (
"\nReproduced experiment(s): %s\n"
"To promote an experiment to a Git branch run:\n\n"
"\tdvc exp branch <exp>\n\n"
"To apply the results of an experiment to your workspace run:\n\n"
"\tdvc exp apply <exp>"
).format(", ".join(names))
logger.info(msg)
)
logger.info(fmt, ", ".join(names))

@scm_locked
def new(
Expand Down Expand Up @@ -621,11 +621,24 @@ def get_exact_name(self, rev: str):
return ExpRefInfo.from_ref(ref).name
return None

def iter_ref_infos_by_name(self, name: str):
for ref in self.scm.iter_refs(base=EXPS_NAMESPACE):
if ref.startswith(EXEC_NAMESPACE) or ref == EXPS_STASH:
continue
ref_info = ExpRefInfo.from_ref(ref)
if ref_info.name == name:
yield ref_info

def apply(self, *args, **kwargs):
from dvc.repo.experiments.apply import apply

return apply(self.repo, *args, **kwargs)

def branch(self, *args, **kwargs):
from dvc.repo.experiments.branch import branch

return branch(self.repo, *args, **kwargs)

def diff(self, *args, **kwargs):
from dvc.repo.experiments.diff import diff

Expand Down
39 changes: 39 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,42 @@ def test_detached_parent(tmp_dir, scm, dvc, exp_stage, mocker):

dvc.experiments.apply(exp_rev)
assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3"


def test_branch(tmp_dir, scm, dvc, exp_stage):
from dvc.exceptions import InvalidArgumentError

with pytest.raises(InvalidArgumentError):
dvc.experiments.branch("foo", "branch")

scm.branch("branch-exists")

results = dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], name="foo"
)
exp_a = first(results)
ref_a = dvc.experiments.get_branch_containing(exp_a)

with pytest.raises(InvalidArgumentError):
dvc.experiments.branch("foo", "branch-exists")
dvc.experiments.branch("foo", "branch-name")
dvc.experiments.branch(exp_a, "branch-rev")
dvc.experiments.branch(ref_a, "branch-ref")

for name in ["branch-name", "branch-rev", "branch-ref"]:
assert name in scm.list_branches()
assert scm.resolve_rev(name) == exp_a

tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline")
results = dvc.experiments.run(
exp_stage.addressing, params=["foo=2"], name="foo"
)
exp_b = first(results)
ref_b = dvc.experiments.get_branch_containing(exp_b)

with pytest.raises(InvalidArgumentError):
dvc.experiments.branch("foo", "branch-name")
dvc.experiments.branch(ref_b, "branch-ref-b")

assert "branch-ref-b" in scm.list_branches()
assert scm.resolve_rev("branch-ref-b") == exp_b
13 changes: 13 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dvc.cli import parse_args
from dvc.command.experiments import (
CmdExperimentsApply,
CmdExperimentsBranch,
CmdExperimentsDiff,
CmdExperimentsGC,
CmdExperimentsRun,
Expand Down Expand Up @@ -137,3 +138,15 @@ def test_experiments_gc(dvc, mocker):
cmd = cli_args.func(cli_args)
with pytest.raises(InvalidArgumentError):
cmd.run()


def test_experiments_branch(dvc, mocker):
cli_args = parse_args(["experiments", "branch", "expname", "branchname"])
assert cli_args.func == CmdExperimentsBranch

cmd = cli_args.func(cli_args)
m = mocker.patch("dvc.repo.experiments.branch.branch", return_value={})

assert cmd.run() == 0

m.assert_called_once_with(cmd.repo, "expname", "branchname")

0 comments on commit 27831e7

Please sign in to comment.