diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 3c98e9dd02..bdb9226f5a 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -670,7 +670,7 @@ def add_parser(subparsers, parent_parser): ) experiments_run_parser.add_argument( "--continue", - nargs=1, + type=str, default=None, dest="checkpoint_continue", help=( diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 361757b429..c1e488b24d 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -66,6 +66,15 @@ def __init__(self, rev, expected): self.expected_rev = expected +class MultipleBranchError(DvcException): + def __init__(self, rev): + super().__init__( + f"Ambiguous commit '{rev[:7]}' belongs to multiple experiment " + "branches." + ) + self.rev = rev + + class Experiments: """Class that manages experiments in a DVC repo. @@ -322,15 +331,16 @@ def _commit( if ( check_exists or checkpoint ) and exp_name in self.scm.list_branches(): + branch_tip = self.scm.resolve_rev(exp_name) if checkpoint: raise DvcException( f"Checkpoint experiment containing '{rev[:7]}' " "already exists. To resume the experiment, " "run:\n\n" - f"\tdvc exp run --continue {rev[:7]}" + f"\tdvc exp run --continue {branch_tip[:7]}" ) logger.debug("Using existing experiment branch '%s'", exp_name) - return self.scm.resolve_rev(exp_name) + return branch_tip self.scm.checkout(exp_name, create_new=True) logger.debug("Commit new experiment branch '%s'", exp_name) else: @@ -383,7 +393,8 @@ def new( workspace. """ if checkpoint_continue: - branch = self._get_branch_containing(checkpoint_continue) + rev = self.scm.resolve_rev(checkpoint_continue) + branch = self._get_branch_containing(rev) if not branch: raise DvcException( "Could not find checkpoint experiment " @@ -772,7 +783,12 @@ def _get_branch_containing(self, rev): if self.scm.repo.head.is_detached: self._checkout_default_branch() try: - name = self.scm.repo.git.branch(contains=rev) + names = self.scm.repo.git.branch(contains=rev).strip().splitlines() + if not names: + return None + if len(names) > 1: + raise MultipleBranchError(rev) + name = names[0] if name.startswith("*"): name = name[1:] return name.rsplit("/")[-1].strip()