From 27831e7ff1ba98d1d81fcb9d09453642b65ce936 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 2 Dec 2020 20:01:49 +0900 Subject: [PATCH] experiments: add `exp branch` for promoting experiment to full branch --- dvc/command/experiments.py | 28 ++++++++++++++++ dvc/repo/experiments/__init__.py | 21 +++++++++--- tests/func/experiments/test_experiments.py | 39 ++++++++++++++++++++++ tests/unit/command/test_experiments.py | 13 ++++++++ 4 files changed, 97 insertions(+), 4 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 0475b90730..c3d2b5747f 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -534,6 +534,16 @@ def run(self): return 0 +class CmdExperimentsBranch(CmdBase): + def run(self): + if not self.repo.experiments: + return 0 + + self.repo.experiments.branch(self.args.experiment, self.args.branch) + + return 0 + + def add_parser(subparsers, parent_parser): EXPERIMENTS_HELP = "Commands to display and compare experiments." @@ -833,6 +843,24 @@ def add_parser(subparsers, parent_parser): ) experiments_gc_parser.set_defaults(func=CmdExperimentsGC) + EXPERIMENTS_BRANCH_HELP = "Promote an experiment to a Git branch." + experiments_branch_parser = experiments_subparsers.add_parser( + "branch", + parents=[parent_parser], + description=append_doc_link( + EXPERIMENTS_BRANCH_HELP, "experiments/branch" + ), + help=EXPERIMENTS_BRANCH_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + experiments_branch_parser.add_argument( + "experiment", help="Experiment to be promoted.", + ) + experiments_branch_parser.add_argument( + "branch", help="Git branch name to use.", + ) + experiments_branch_parser.set_defaults(func=CmdExperimentsBranch) + def _add_run_common(parser): """Add common args for 'exp run' and 'exp resume'.""" diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 3a47b7b113..e32357834f 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -293,14 +293,14 @@ def _log_reproduced(self, revs: Iterable[str]): for rev in revs: name = self.get_exact_name(rev) names.append(name if name else rev[:7]) - msg = ( - "\nReproduced experiment(s): {}\n" + fmt = ( + "\nReproduced experiment(s): %s\n" "To promote an experiment to a Git branch run:\n\n" "\tdvc exp branch \n\n" "To apply the results of an experiment to your workspace run:\n\n" "\tdvc exp apply " - ).format(", ".join(names)) - logger.info(msg) + ) + logger.info(fmt, ", ".join(names)) @scm_locked def new( @@ -621,11 +621,24 @@ def get_exact_name(self, rev: str): return ExpRefInfo.from_ref(ref).name return None + def iter_ref_infos_by_name(self, name: str): + for ref in self.scm.iter_refs(base=EXPS_NAMESPACE): + if ref.startswith(EXEC_NAMESPACE) or ref == EXPS_STASH: + continue + ref_info = ExpRefInfo.from_ref(ref) + if ref_info.name == name: + yield ref_info + def apply(self, *args, **kwargs): from dvc.repo.experiments.apply import apply return apply(self.repo, *args, **kwargs) + def branch(self, *args, **kwargs): + from dvc.repo.experiments.branch import branch + + return branch(self.repo, *args, **kwargs) + def diff(self, *args, **kwargs): from dvc.repo.experiments.diff import diff diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index b63564dfe9..032e7e82ee 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -266,3 +266,42 @@ def test_detached_parent(tmp_dir, scm, dvc, exp_stage, mocker): dvc.experiments.apply(exp_rev) assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3" + + +def test_branch(tmp_dir, scm, dvc, exp_stage): + from dvc.exceptions import InvalidArgumentError + + with pytest.raises(InvalidArgumentError): + dvc.experiments.branch("foo", "branch") + + scm.branch("branch-exists") + + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], name="foo" + ) + exp_a = first(results) + ref_a = dvc.experiments.get_branch_containing(exp_a) + + with pytest.raises(InvalidArgumentError): + dvc.experiments.branch("foo", "branch-exists") + dvc.experiments.branch("foo", "branch-name") + dvc.experiments.branch(exp_a, "branch-rev") + dvc.experiments.branch(ref_a, "branch-ref") + + for name in ["branch-name", "branch-rev", "branch-ref"]: + assert name in scm.list_branches() + assert scm.resolve_rev(name) == exp_a + + tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline") + results = dvc.experiments.run( + exp_stage.addressing, params=["foo=2"], name="foo" + ) + exp_b = first(results) + ref_b = dvc.experiments.get_branch_containing(exp_b) + + with pytest.raises(InvalidArgumentError): + dvc.experiments.branch("foo", "branch-name") + dvc.experiments.branch(ref_b, "branch-ref-b") + + assert "branch-ref-b" in scm.list_branches() + assert scm.resolve_rev("branch-ref-b") == exp_b diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index f5c1bd47cc..0f34088303 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -3,6 +3,7 @@ from dvc.cli import parse_args from dvc.command.experiments import ( CmdExperimentsApply, + CmdExperimentsBranch, CmdExperimentsDiff, CmdExperimentsGC, CmdExperimentsRun, @@ -137,3 +138,15 @@ def test_experiments_gc(dvc, mocker): cmd = cli_args.func(cli_args) with pytest.raises(InvalidArgumentError): cmd.run() + + +def test_experiments_branch(dvc, mocker): + cli_args = parse_args(["experiments", "branch", "expname", "branchname"]) + assert cli_args.func == CmdExperimentsBranch + + cmd = cli_args.func(cli_args) + m = mocker.patch("dvc.repo.experiments.branch.branch", return_value={}) + + assert cmd.run() == 0 + + m.assert_called_once_with(cmd.repo, "expname", "branchname")