From 09e51e655d9299f229834094d571086ed31e8b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Rowlands=20=28=EB=B3=80=EA=B8=B0=ED=98=B8=29?= Date: Wed, 14 Oct 2020 19:12:57 +0900 Subject: [PATCH] checkpoints: cleanup --continue usage (#4716) * warn if rev provided with --continue is ambiguous * always provide tip of branch when suggesting to use --continue --- dvc/command/experiments.py | 2 +- dvc/repo/experiments/__init__.py | 24 ++++++++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 3c98e9dd02..bdb9226f5a 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -670,7 +670,7 @@ def add_parser(subparsers, parent_parser): ) experiments_run_parser.add_argument( "--continue", - nargs=1, + type=str, default=None, dest="checkpoint_continue", help=( diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 361757b429..c1e488b24d 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -66,6 +66,15 @@ def __init__(self, rev, expected): self.expected_rev = expected +class MultipleBranchError(DvcException): + def __init__(self, rev): + super().__init__( + f"Ambiguous commit '{rev[:7]}' belongs to multiple experiment " + "branches." + ) + self.rev = rev + + class Experiments: """Class that manages experiments in a DVC repo. @@ -322,15 +331,16 @@ def _commit( if ( check_exists or checkpoint ) and exp_name in self.scm.list_branches(): + branch_tip = self.scm.resolve_rev(exp_name) if checkpoint: raise DvcException( f"Checkpoint experiment containing '{rev[:7]}' " "already exists. To resume the experiment, " "run:\n\n" - f"\tdvc exp run --continue {rev[:7]}" + f"\tdvc exp run --continue {branch_tip[:7]}" ) logger.debug("Using existing experiment branch '%s'", exp_name) - return self.scm.resolve_rev(exp_name) + return branch_tip self.scm.checkout(exp_name, create_new=True) logger.debug("Commit new experiment branch '%s'", exp_name) else: @@ -383,7 +393,8 @@ def new( workspace. """ if checkpoint_continue: - branch = self._get_branch_containing(checkpoint_continue) + rev = self.scm.resolve_rev(checkpoint_continue) + branch = self._get_branch_containing(rev) if not branch: raise DvcException( "Could not find checkpoint experiment " @@ -772,7 +783,12 @@ def _get_branch_containing(self, rev): if self.scm.repo.head.is_detached: self._checkout_default_branch() try: - name = self.scm.repo.git.branch(contains=rev) + names = self.scm.repo.git.branch(contains=rev).strip().splitlines() + if not names: + return None + if len(names) > 1: + raise MultipleBranchError(rev) + name = names[0] if name.startswith("*"): name = name[1:] return name.rsplit("/")[-1].strip()