diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index ba2468f75c..a86a0b3675 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -417,22 +417,44 @@ def _resume_checkpoint( ) last_applied = self.scm.get_ref(EXEC_APPLY) + try: + if last_applied: + self.check_baseline(last_applied) + self.check_baseline(resume_rev) + except BaselineMismatchError: + # If HEAD has moved since the the last applied checkpoint, + # the applied checkpoint is no longer valid + self.scm.remove_ref(EXEC_APPLY) + last_applied = None + checkpoint_resume = None if resume_rev != last_applied: - if last_applied is None: - msg = "Current workspace does not contain an experiment. " + if checkpoint_resume == self.LAST_CHECKPOINT: + display_rev: Optional[str] = resume_rev[:7] else: + display_rev = checkpoint_resume + + if display_rev: + if last_applied is None: + msg = ( + f"Checkpoint '{display_rev}' cannot be resumed until " + "it is applied to your workspace." + ) + else: + msg = ( + f"Checkpoint '{display_rev}' does not match the " + "most recently applied experiment in your workspace " + f"('{last_applied[:7]}')." + ) msg = ( - f"Checkpoint '{checkpoint_resume[:7]}' does not match the " - "most recently applied experiment in your workspace " - f"('{last_applied[:7]}')." + f"{msg}\n" + "To resume this experiment run:\n\n" + f"\tdvc exp apply {display_rev}\n\n" + "And then retry this 'dvc exp res' command." ) + else: + msg = "No existing checkpoint to resume in your workspace." - raise DvcException( - f"{msg}\n" - "To resume this experiment run:\n\n" - f"\tdvc exp apply {checkpoint_resume[:7]}\n\n" - "And then retry this 'dvc exp res' command." - ) + raise DvcException(msg) baseline_rev = self._get_baseline(branch) logger.debug( @@ -449,7 +471,7 @@ def _resume_checkpoint( **kwargs, ) - def _get_last_checkpoint(self): + def _get_last_checkpoint(self) -> str: rev = self.scm.get_ref(EXEC_CHECKPOINT) if rev: return rev @@ -674,6 +696,7 @@ def _workspace_repro(self) -> Mapping[str, str]: elif self.scm.get_ref(EXEC_BRANCH): self.scm.remove_ref(EXEC_BRANCH) try: + orig_checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) exec_result = BaseExecutor.reproduce( None, rev, @@ -701,6 +724,9 @@ def _workspace_repro(self) -> Mapping[str, str]: self.scm.remove_ref(EXEC_BASELINE) if entry.branch: self.scm.remove_ref(EXEC_BRANCH) + checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) + if checkpoint and checkpoint != orig_checkpoint: + self.scm.set_ref(EXEC_APPLY, checkpoint) def check_baseline(self, exp_rev): baseline_sha = self.repo.scm.get_rev() diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 8e74e58b7c..d7861d5804 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -2,7 +2,7 @@ import typing from functools import partial -from dvc.exceptions import ReproductionError +from dvc.exceptions import DvcException, ReproductionError from dvc.repo.scm_context import scm_context from . import locked @@ -24,9 +24,9 @@ def _run_callback(repro_callback): if checkpoint_func: kwargs["checkpoint_func"] = partial(_run_callback, checkpoint_func) else: - logger.warning( - "Checkpoint stages are not fully supported in 'dvc repro'. " - "Checkpoint stages should be reproduced with 'dvc exp run' " + raise DvcException( + "Checkpoint stages are not supported in 'dvc repro'. " + "Checkpoint stages must be reproduced with 'dvc exp run' " "or 'dvc exp resume'." )