diff --git a/dvc/api.py b/dvc/api.py index bf79b682e6..0d0a45ba21 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -109,3 +109,34 @@ def _make_repo(repo_url=None, rev=None): pass # fallthrough to external_repo with external_repo(url=repo_url, rev=rev) as repo: yield repo + + +def make_checkpoint(): + """ + Signal DVC to create a checkpoint experiment. + + If the current process is being run from DVC, this function will block + until DVC has finished creating the checkpoint. Otherwise, this function + will return immediately. + """ + import builtins + from time import sleep + + from dvc.stage.run import CHECKPOINT_SIGNAL_FILE + + if os.getenv("DVC_CHECKPOINT") is None: + return + + root_dir = Repo.find_root() + signal_file = os.path.join( + root_dir, Repo.DVC_DIR, "tmp", CHECKPOINT_SIGNAL_FILE + ) + + with builtins.open(signal_file, "w") as fobj: + # NOTE: force flushing/writing empty file to disk, otherwise when + # run in certain contexts (pytest) file may not actually be written + fobj.write("") + fobj.flush() + os.fsync(fobj.fileno()) + while os.path.exists(signal_file): + sleep(1) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 8b54ed01bf..43927ee074 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1,13 +1,17 @@ import argparse import io import logging +import os from collections import OrderedDict from datetime import date from itertools import groupby from typing import Iterable, Optional +from dvc.command import completion from dvc.command.base import CmdBase, append_doc_link, fix_subparsers -from dvc.command.metrics import DEFAULT_PRECISION +from dvc.command.metrics import DEFAULT_PRECISION, _show_metrics +from dvc.command.status import CmdDataStatus +from dvc.dvcfile import PIPELINE_FILE from dvc.exceptions import DvcException, InvalidArgumentError from dvc.utils.flatten import flatten @@ -109,19 +113,30 @@ def _collect_rows( reverse = sort_order == "desc" experiments = _sort_exp(experiments, sort_by, sort_type, reverse) + last_tip = None for i, (rev, exp) in enumerate(experiments.items()): row = [] style = None queued = "*" if exp.get("queued", False) else "" + tip = exp.get("checkpoint_tip") if rev == "baseline": name = exp.get("name", base_rev) row.append(f"{name}") style = "bold" - elif i < len(experiments) - 1: - row.append(f"├── {queued}{rev[:7]}") else: - row.append(f"└── {queued}{rev[:7]}") + if tip and tip == last_tip: + tree = "│ ╟" + else: + if i < len(experiments) - 1: + if tip: + tree = "├─╥" + else: + tree = "├──" + else: + tree = "└──" + row.append(f"{tree} {queued}{rev[:7]}") + last_tip = tip if not no_timestamp: row.append(_format_time(exp.get("timestamp"))) @@ -373,6 +388,64 @@ def run(self): return 0 +class CmdExperimentsRun(CmdBase): + def run(self): + if not self.repo.experiments: + return 0 + + saved_dir = os.path.realpath(os.curdir) + os.chdir(self.args.cwd) + + # Dirty hack so the for loop below can at least enter once + if self.args.all_pipelines: + self.args.targets = [None] + elif not self.args.targets: + self.args.targets = self.default_targets + + ret = 0 + for target in self.args.targets: + try: + stages = self.repo.reproduce( + target, + single_item=self.args.single_item, + force=self.args.force, + dry=self.args.dry, + interactive=self.args.interactive, + pipeline=self.args.pipeline, + all_pipelines=self.args.all_pipelines, + run_cache=not self.args.no_run_cache, + no_commit=self.args.no_commit, + downstream=self.args.downstream, + recursive=self.args.recursive, + force_downstream=self.args.force_downstream, + experiment=True, + queue=self.args.queue, + run_all=self.args.run_all, + jobs=self.args.jobs, + params=self.args.params, + checkpoint=( + self.args.checkpoint + or self.args.checkpoint_continue is not None + ), + checkpoint_continue=self.args.checkpoint_continue, + ) + + if len(stages) == 0: + logger.info(CmdDataStatus.UP_TO_DATE_MSG) + + if self.args.metrics: + metrics = self.repo.metrics.show() + logger.info(_show_metrics(metrics)) + + except DvcException: + logger.exception("") + ret = 1 + break + + os.chdir(saved_dir) + return ret + + def add_parser(subparsers, parent_parser): EXPERIMENTS_HELP = "Commands to display and compare experiments." @@ -552,3 +625,156 @@ def add_parser(subparsers, parent_parser): metavar="", ) experiments_diff_parser.set_defaults(func=CmdExperimentsDiff) + + EXPERIMENTS_RUN_HELP = ( + "Reproduce complete or partial experiment pipelines." + ) + experiments_run_parser = experiments_subparsers.add_parser( + "run", + parents=[parent_parser], + description=append_doc_link(EXPERIMENTS_RUN_HELP, "experiments/run"), + help=EXPERIMENTS_RUN_HELP, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + experiments_run_parser.add_argument( + "targets", + nargs="*", + help=f"Stages to reproduce. '{PIPELINE_FILE}' by default.", + ).complete = completion.DVC_FILE + experiments_run_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Reproduce even if dependencies were not changed.", + ) + experiments_run_parser.add_argument( + "-s", + "--single-item", + action="store_true", + default=False, + help="Reproduce only single data item without recursive dependencies " + "check.", + ) + experiments_run_parser.add_argument( + "-c", + "--cwd", + default=os.path.curdir, + help="Directory within your repo to reproduce from. Note: deprecated " + "by `dvc --cd `.", + metavar="", + ) + experiments_run_parser.add_argument( + "-m", + "--metrics", + action="store_true", + default=False, + help="Show metrics after reproduction.", + ) + experiments_run_parser.add_argument( + "--dry", + action="store_true", + default=False, + help="Only print the commands that would be executed without " + "actually executing.", + ) + experiments_run_parser.add_argument( + "-i", + "--interactive", + action="store_true", + default=False, + help="Ask for confirmation before reproducing each stage.", + ) + experiments_run_parser.add_argument( + "-p", + "--pipeline", + action="store_true", + default=False, + help="Reproduce the whole pipeline that the specified stage file " + "belongs to.", + ) + experiments_run_parser.add_argument( + "-P", + "--all-pipelines", + action="store_true", + default=False, + help="Reproduce all pipelines in the repo.", + ) + experiments_run_parser.add_argument( + "-R", + "--recursive", + action="store_true", + default=False, + help="Reproduce all stages in the specified directory.", + ) + experiments_run_parser.add_argument( + "--no-run-cache", + action="store_true", + default=False, + help=( + "Execute stage commands even if they have already been run with " + "the same command/dependencies/outputs/etc before." + ), + ) + experiments_run_parser.add_argument( + "--force-downstream", + action="store_true", + default=False, + help="Reproduce all descendants of a changed stage even if their " + "direct dependencies didn't change.", + ) + experiments_run_parser.add_argument( + "--no-commit", + action="store_true", + default=False, + help="Don't put files/directories into cache.", + ) + experiments_run_parser.add_argument( + "--downstream", + action="store_true", + default=False, + help="Start from the specified stages when reproducing pipelines.", + ) + experiments_run_parser.add_argument( + "--params", + action="append", + default=[], + help="Use the specified param values when reproducing pipelines.", + metavar="[:]", + ) + experiments_run_parser.add_argument( + "--queue", + action="store_true", + default=False, + help="Stage this experiment in the run queue for future execution.", + ) + experiments_run_parser.add_argument( + "--run-all", + action="store_true", + default=False, + help="Execute all experiments in the run queue.", + ) + experiments_run_parser.add_argument( + "-j", + "--jobs", + type=int, + help="Run the specified number of experiments at a time in parallel.", + metavar="", + ) + experiments_run_parser.add_argument( + "--checkpoint", + action="store_true", + default=False, + help="Reproduce pipelines as a checkpoint experiment.", + ) + experiments_run_parser.add_argument( + "--continue", + nargs=1, + default=None, + dest="checkpoint_continue", + help=( + "Continue from the specified checkpoint experiment" + "(implies --checkpoint)." + ), + ) + experiments_run_parser.set_defaults(func=CmdExperimentsRun) diff --git a/dvc/output/base.py b/dvc/output/base.py index f575685b92..d0e4b14686 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -8,6 +8,7 @@ import dvc.prompt as prompt from dvc.cache import NamedCache from dvc.exceptions import ( + CheckoutError, CollectCacheError, DvcException, MergeError, @@ -325,6 +326,7 @@ def checkout( progress_callback=None, relink=False, filter_info=None, + allow_persist_missing=False, ): if not self.use_cache: if progress_callback: @@ -333,14 +335,19 @@ def checkout( ) return None - return self.cache.checkout( - self.path_info, - self.hash_info, - force=force, - progress_callback=progress_callback, - relink=relink, - filter_info=filter_info, - ) + try: + return self.cache.checkout( + self.path_info, + self.hash_info, + force=force, + progress_callback=progress_callback, + relink=relink, + filter_info=filter_info, + ) + except CheckoutError: + if self.persist and allow_persist_missing: + return None + raise def remove(self, ignore_remove=False): self.tree.remove(self.path_info) diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index ce7238472e..f59814473f 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -46,6 +46,7 @@ def checkout( force=False, relink=False, recursive=False, + allow_persist_missing=False, ): from dvc.stage.exceptions import ( StageFileBadNameError, @@ -96,6 +97,7 @@ def checkout( progress_callback=pbar.update_msg, relink=relink, filter_info=filter_info, + allow_persist_missing=allow_persist_missing, ) for key, items in result.items(): stats[key].extend(_fspath_dir(path) for path in items) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 4822ca6457..73498c309e 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -79,7 +79,8 @@ class Experiments: r"dvc-exp:(?P[0-9a-f]+)(:(?P.+))?$" ) BRANCH_RE = re.compile( - r"^(?P[a-f0-9]{7})-(?P[a-f0-9]+)$" + r"^(?P[a-f0-9]{7})-(?P[a-f0-9]+)" + r"(?P-checkpoint)?$" ) StashEntry = namedtuple("StashEntry", ["index", "baseline_rev", "branch"]) @@ -302,13 +303,16 @@ def _update(dict_, other): with modify_data(path, tree=self.exp_dvc.tree) as data: _update(data, params[params_fname]) - def _commit(self, exp_hash, check_exists=True, create_branch=True): + def _commit( + self, exp_hash, check_exists=True, create_branch=True, checkpoint=False + ): """Commit stages as an experiment and return the commit SHA.""" if not self.scm.is_dirty(): raise UnchangedExperimentError(self.scm.get_rev()) rev = self.scm.get_rev() - exp_name = f"{rev[:7]}-{exp_hash}" + checkpoint = "-checkpoint" if checkpoint else "" + exp_name = f"{rev[:7]}-{exp_hash}{checkpoint}" if create_branch: if check_exists and exp_name in self.scm.list_branches(): logger.debug("Using existing experiment branch '%s'", exp_name) @@ -323,13 +327,16 @@ def _commit(self, exp_hash, check_exists=True, create_branch=True): def reproduce_one(self, queue=False, **kwargs): """Reproduce and checkout a single experiment.""" + checkpoint = kwargs.get("checkpoint", False) stash_rev = self.new(**kwargs) if queue: logger.info( "Queued experiment '%s' for future execution.", stash_rev[:7] ) return [stash_rev] - results = self.reproduce([stash_rev], keep_stash=False) + results = self.reproduce( + [stash_rev], keep_stash=False, checkpoint=checkpoint + ) exp_rev = first(results) if exp_rev is not None: self.checkout_exp(exp_rev) @@ -348,13 +355,31 @@ def reproduce_queued(self, **kwargs): return results @scm_locked - def new(self, *args, **kwargs): + def new( + self, + *args, + checkpoint: Optional[bool] = False, + checkpoint_continue: Optional[str] = None, + branch: Optional[str] = None, + **kwargs, + ): """Create a new experiment. Experiment will be reproduced and checked out into the user's workspace. """ - branch = kwargs.get("branch") + if checkpoint_continue: + branch = self._get_branch_containing(checkpoint_continue) + if not branch: + raise DvcException( + "Could not find checkpoint experiment " + f"'{checkpoint_continue}'" + ) + logger.debug( + "Continuing checkpoint experiment '%s'", checkpoint_continue + ) + kwargs["apply_workspace"] = False + if branch: rev = self.scm.resolve_rev(branch) logger.debug( @@ -363,8 +388,11 @@ def new(self, *args, **kwargs): else: rev = self.repo.scm.get_rev() self._scm_checkout(rev) + try: - stash_rev = self._stash_exp(*args, **kwargs) + stash_rev = self._stash_exp( + *args, branch=branch, allow_unchanged=checkpoint, **kwargs + ) except UnchangedExperimentError as exc: logger.info("Reproducing existing experiment '%s'.", rev[:7]) raise exc @@ -378,6 +406,7 @@ def reproduce( self, revs: Optional[Iterable] = None, keep_stash: Optional[bool] = True, + checkpoint: Optional[bool] = False, **kwargs, ): """Reproduce the specified experiments. @@ -428,7 +457,10 @@ def reproduce( self._collect_input(executor) executors[rev] = executor - exec_results = self._reproduce(executors, **kwargs) + if checkpoint: + exec_results = self._reproduce_checkpoint(executors) + else: + exec_results = self._reproduce(executors, **kwargs) if keep_stash: # only drop successfully run stashed experiments @@ -480,30 +512,10 @@ def _reproduce(self, executors: dict, jobs: Optional[int] = 1) -> dict: self._scm_checkout(executor.branch) else: self._scm_checkout(executor.baseline_rev) - try: - self._collect_output(executor) - except DownloadError: - logger.error( - "Failed to collect output for experiment '%s'", - rev, - ) - continue - finally: - if os.path.exists(self.args_file): - remove(self.args_file) - - try: - branch = not executor.branch - exp_rev = self._commit(exp_hash, create_branch=branch) - except UnchangedExperimentError: - logger.debug( - "Experiment '%s' identical to baseline '%s'", - rev, - executor.baseline_rev, - ) - exp_rev = executor.baseline_rev - logger.info("Reproduced experiment '%s'.", exp_rev[:7]) - result[rev] = {exp_rev: exp_hash} + exp_rev = self._collect_and_commit(rev, executor, exp_hash) + if exp_rev: + logger.info("Reproduced experiment '%s'.", exp_rev[:7]) + result[rev] = {exp_rev: exp_hash} else: logger.exception( "Failed to reproduce experiment '%s'", rev @@ -512,6 +524,76 @@ def _reproduce(self, executors: dict, jobs: Optional[int] = 1) -> dict: return result + def _reproduce_checkpoint(self, executors): + result = {} + for rev, executor in executors.items(): + logger.debug("Reproducing checkpoint experiment '%s'", rev[:7]) + + if executor.branch: + self._scm_checkout(executor.branch) + else: + self._scm_checkout(executor.baseline_rev) + + def _checkpoint_callback(rev, executor, unchanged, stages): + exp_hash = hash_exp(stages + unchanged) + exp_rev = self._collect_and_commit( + rev, executor, exp_hash, checkpoint=True + ) + if exp_rev: + if not executor.branch: + branch = self._get_branch_containing(exp_rev) + executor.branch = branch + logger.info( + "Checkpoint experiment iteration '%s'.", exp_rev[:7] + ) + result[rev] = {exp_rev: exp_hash} + + checkpoint_func = partial(_checkpoint_callback, rev, executor) + executor.reproduce( + executor.dvc_dir, + cwd=executor.dvc.root_dir, + checkpoint=True, + checkpoint_func=checkpoint_func, + **executor.repro_kwargs, + ) + + # TODO: determine whether or not we should create a final + # checkpoint commit after repro is killed, or if we should only do + # it on explicit make_checkpoint() signals. + + # NOTE: our cached GitPython Repo instance cannot be re-used if the + # checkpoint run was interrupted via SIGINT, so we need this hack + # to create a new git repo instance after checkpoint runs. + del self.scm + + return result + + def _collect_and_commit(self, rev, executor, exp_hash, **kwargs): + try: + self._collect_output(executor) + except DownloadError: + logger.error( + "Failed to collect output for experiment '%s'", rev, + ) + return None + finally: + if os.path.exists(self.args_file): + remove(self.args_file) + + try: + create_branch = not executor.branch + exp_rev = self._commit( + exp_hash, create_branch=create_branch, **kwargs + ) + except UnchangedExperimentError: + logger.debug( + "Experiment '%s' identical to baseline '%s'", + rev, + executor.baseline_rev, + ) + exp_rev = executor.baseline_rev + return exp_rev + def _collect_input(self, executor: ExperimentExecutor): """Copy (upload) input from the experiments workspace to the executor tree. @@ -674,6 +756,8 @@ def _get_branch_containing(self, rev): self._checkout_default_branch() try: name = self.scm.repo.git.branch(contains=rev) + if name.startswith("*"): + name = name[1:] return name.rsplit("/")[-1].strip() except GitCommandError: pass diff --git a/dvc/repo/experiments/executor.py b/dvc/repo/experiments/executor.py index 789087ed16..3c1122f51f 100644 --- a/dvc/repo/experiments/executor.py +++ b/dvc/repo/experiments/executor.py @@ -134,9 +134,10 @@ def reproduce(dvc_dir, cwd=None, **kwargs): unchanged = [] - def filter_pipeline(stage): - if isinstance(stage, PipelineStage): - unchanged.append(stage) + def filter_pipeline(stages): + unchanged.extend( + [stage for stage in stages if isinstance(stage, PipelineStage)] + ) if cwd: old_cwd = os.getcwd() @@ -148,7 +149,20 @@ def filter_pipeline(stage): try: logger.debug("Running repro in '%s'", cwd) dvc = Repo(dvc_dir) - dvc.checkout() + + # NOTE: for checkpoint experiments we handle persist outs slightly + # differently than normal: + # + # - checkpoint out may not yet exist if this is the first time this + # experiment has been run, this is not an error condition for + # experiments + # - at the start of a repro run, we need to remove the persist out + # and restore it to its last known (committed) state (which may + # be removed/does not yet exist) so that our executor workspace + # is not polluted with the (persistent) out from an unrelated + # experiment run + checkpoint = kwargs.pop("checkpoint", False) + dvc.checkout(allow_persist_missing=checkpoint, force=checkpoint) stages = dvc.reproduce(on_unchanged=filter_pipeline, **kwargs) finally: if old_cwd is not None: diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py index 54f448e8d2..af63d35d00 100644 --- a/dvc/repo/experiments/show.py +++ b/dvc/repo/experiments/show.py @@ -1,6 +1,9 @@ import logging from collections import OrderedDict, defaultdict from datetime import datetime +from typing import Optional + +from funcy import first from dvc.repo import locked from dvc.repo.metrics.show import _collect_metrics, _read_metrics @@ -9,11 +12,11 @@ logger = logging.getLogger(__name__) -def _collect_experiment(repo, branch, stash=False, sha_only=True): +def _collect_experiment(repo, rev, stash=False, sha_only=True): from git.exc import GitCommandError res = defaultdict(dict) - for rev in repo.brancher(revs=[branch]): + for rev in repo.brancher(revs=[rev]): if rev == "workspace": res["timestamp"] = None else: @@ -53,6 +56,29 @@ def _resolve_commit(repo, rev): return commit +def _collect_checkpoint_experiment(repo, branch, baseline, **kwargs): + res = OrderedDict() + exp_rev = repo.scm.resolve_rev(branch) + for rev in _branch_revs(repo, exp_rev, baseline): + res[rev] = _collect_experiment(repo, rev, **kwargs) + res[rev]["checkpoint_tip"] = exp_rev + return res + + +def _branch_revs(repo, branch_tip, baseline: Optional[str] = None): + """Iterate over revisions in a given branch (from newest to oldest). + + If baseline is set, iterator will stop when the specified revision is + reached. + """ + commit = _resolve_commit(repo, branch_tip) + while commit is not None: + yield commit.hexsha + commit = first(commit.parents) + if commit and commit.hexsha == baseline: + return + + @locked def show( repo, @@ -89,12 +115,18 @@ def show( if m: rev = repo.scm.resolve_rev(m.group("baseline_rev")) if rev in revs: - exp_rev = repo.experiments.scm.resolve_rev(exp_branch) with repo.experiments.chdir(): - experiment = _collect_experiment( - repo.experiments.exp_dvc, exp_branch - ) - res[rev][exp_rev] = experiment + if m.group("checkpoint"): + checkpoint_exps = _collect_checkpoint_experiment( + repo.experiments.exp_dvc, exp_branch, rev + ) + res[rev].update(checkpoint_exps) + else: + exp_rev = repo.experiments.scm.resolve_rev(exp_branch) + experiment = _collect_experiment( + repo.experiments.exp_dvc, exp_branch + ) + res[rev][exp_rev] = experiment # collect queued (not yet reproduced) experiments for stash_rev, entry in repo.experiments.stash_revs.items(): diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 8fb665f877..529261987b 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -1,4 +1,5 @@ import logging +from functools import partial from dvc.exceptions import InvalidArgumentError, ReproductionError from dvc.repo.experiments import UnchangedExperimentError @@ -11,6 +12,14 @@ def _reproduce_stage(stage, **kwargs): + def _run_callback(repro_callback): + _dump_stage(stage) + repro_callback([stage]) + + checkpoint_func = kwargs.pop("checkpoint_func", None) + if checkpoint_func: + kwargs["checkpoint_func"] = partial(_run_callback, checkpoint_func) + if stage.frozen and not stage.is_import: logger.warning( "{} is frozen. Its dependencies are" @@ -22,14 +31,18 @@ def _reproduce_stage(stage, **kwargs): return [] if not kwargs.get("dry", False): - from ..dvcfile import Dvcfile - - dvcfile = Dvcfile(stage.repo, stage.path) - dvcfile.dump(stage, update_pipeline=False) + _dump_stage(stage) return [stage] +def _dump_stage(stage): + from ..dvcfile import Dvcfile + + dvcfile = Dvcfile(stage.repo, stage.path) + dvcfile.dump(stage, update_pipeline=False) + + def _get_active_graph(G): import networkx as nx @@ -75,6 +88,8 @@ def reproduce( queue = kwargs.pop("queue", False) run_all = kwargs.pop("run_all", False) jobs = kwargs.pop("jobs", 1) + checkpoint = kwargs.pop("checkpoint", False) + checkpoint_continue = kwargs.pop("checkpoint_continue", None) if (experiment or run_all) and self.experiments: try: return _reproduce_experiments( @@ -86,6 +101,8 @@ def reproduce( queue=queue, run_all=run_all, jobs=jobs, + checkpoint=checkpoint, + checkpoint_continue=checkpoint_continue, **kwargs, ) except UnchangedExperimentError: @@ -219,16 +236,25 @@ def _reproduce_stages( force_downstream = kwargs.pop("force_downstream", False) result = [] + unchanged = [] # `ret` is used to add a cosmetic newline. ret = [] for stage in pipeline: if ret: logger.info("") + checkpoint_func = kwargs.pop("checkpoint_func", None) + if checkpoint_func: + kwargs["checkpoint_func"] = partial( + _repro_callback, checkpoint_func, unchanged + ) + try: ret = _reproduce_stage(stage, **kwargs) - if len(ret) != 0 and force_downstream: + if len(ret) == 0: + unchanged.extend([stage]) + elif force_downstream: # NOTE: we are walking our pipeline from the top to the # bottom. If one stage is changed, it will be reproduced, # which tells us that we should force reproducing all of @@ -238,9 +264,13 @@ def _reproduce_stages( if ret: result.extend(ret) - elif on_unchanged is not None: - on_unchanged(stage) except Exception as exc: raise ReproductionError(stage.relpath) from exc + if on_unchanged is not None: + on_unchanged(unchanged) return result + + +def _repro_callback(experiments_callback, unchanged, stages): + experiments_callback(unchanged, stages) diff --git a/dvc/stage/decorators.py b/dvc/stage/decorators.py index df51f67f8f..7b5f985888 100644 --- a/dvc/stage/decorators.py +++ b/dvc/stage/decorators.py @@ -50,3 +50,19 @@ def wrapper(stage, *args, **kwargs): return ret return wrapper + + +def relock_repo(f): + @wraps(f) + def wrapper(stage, *args, **kwargs): + stage.repo.lock.lock() + stage.repo.state.load() + try: + ret = f(stage, *args, **kwargs) + finally: + stage.repo.state.dump() + stage.repo.lock.unlock() + stage.repo._reset() # pylint: disable=protected-access + return ret + + return wrapper diff --git a/dvc/stage/run.py b/dvc/stage/run.py index 616b5306d6..01c064ccbb 100644 --- a/dvc/stage/run.py +++ b/dvc/stage/run.py @@ -3,15 +3,19 @@ import signal import subprocess import threading +from contextlib import contextmanager from dvc.utils import fix_env -from .decorators import unlocked_repo +from .decorators import relock_repo, unlocked_repo from .exceptions import StageCmdFailedError logger = logging.getLogger(__name__) +CHECKPOINT_SIGNAL_FILE = "DVC_CHECKPOINT" + + def _nix_cmd(executable, cmd): opts = {"zsh": ["--no-rcs"], "bash": ["--noprofile", "--norc"]} name = os.path.basename(executable).lower() @@ -35,8 +39,11 @@ def warn_if_fish(executable): @unlocked_repo -def cmd_run(stage, *args, **kwargs): +def cmd_run(stage, *args, checkpoint=False, **kwargs): kwargs = {"cwd": stage.wdir, "env": fix_env(None), "close_fds": True} + if checkpoint: + # indicate that checkpoint cmd is being run inside DVC + kwargs["env"].update({"DVC_CHECKPOINT": "1"}) if os.name == "nt": kwargs["shell"] = True @@ -81,8 +88,8 @@ def cmd_run(stage, *args, **kwargs): raise StageCmdFailedError(stage.cmd, retcode) -def run_stage(stage, dry=False, force=False, **kwargs): - if not (dry or force): +def run_stage(stage, dry=False, force=False, checkpoint_func=None, **kwargs): + if not (dry or force or checkpoint_func): from .cache import RunCacheNotFoundError try: @@ -99,4 +106,59 @@ def run_stage(stage, dry=False, force=False, **kwargs): ) logger.info("\t%s", stage.cmd) if not dry: - cmd_run(stage) + with checkpoint_monitor(stage, checkpoint_func) as monitor: + cmd_run(stage, checkpoint=monitor is not None) + + +class CheckpointCond: + def __init__(self): + self.done = False + self.cond = threading.Condition() + + def notify(self): + with self.cond: + self.done = True + self.cond.notify() + + def wait(self, timeout=None): + with self.cond: + return self.cond.wait(timeout) or self.done + + +@contextmanager +def checkpoint_monitor(stage, callback_func): + if not callback_func: + yield None + return + + done_cond = CheckpointCond() + monitor_thread = threading.Thread( + target=_checkpoint_run, args=(stage, callback_func, done_cond), + ) + + try: + monitor_thread.start() + yield monitor_thread + finally: + done_cond.notify() + monitor_thread.join() + + +def _checkpoint_run(stage, callback_func, done_cond): + """Run callback_func whenever checkpoint signal file is present.""" + signal_path = os.path.join(stage.repo.tmp_dir, CHECKPOINT_SIGNAL_FILE) + while True: + if os.path.exists(signal_path): + _run_callback(stage, callback_func) + logger.debug("Remove checkpoint signal file") + os.remove(signal_path) + if done_cond.wait(1): + return + + +@relock_repo +def _run_callback(stage, callback_func): + stage.save() + # TODO: do we need commit() (and check for --no-commit) here + logger.debug("Running checkpoint callback for stage '%s'", stage) + callback_func() diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 88a24086a8..6688c6d642 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,8 +1,47 @@ +from textwrap import dedent + import pytest +from funcy import first from dvc.utils.serialize import PythonFileCorruptedError from tests.func.test_repro_multistage import COPY_SCRIPT +CHECKPOINT_SCRIPT_FORMAT = dedent( + """\ + import os + import sys + import shutil + from time import sleep + + from dvc.api import make_checkpoint + + checkpoint_file = {} + checkpoint_iterations = int({}) + if os.path.exists(checkpoint_file): + with open(checkpoint_file) as fobj: + try: + value = int(fobj.read()) + except ValueError: + value = 0 + else: + with open(checkpoint_file, "w"): + pass + value = 0 + + shutil.copyfile({}, {}) + + if os.getenv("DVC_CHECKPOINT"): + for _ in range(checkpoint_iterations): + value += 1 + with open(checkpoint_file, "w") as fobj: + fobj.write(str(value)) + make_checkpoint() +""" +) +CHECKPOINT_SCRIPT = CHECKPOINT_SCRIPT_FORMAT.format( + "sys.argv[1]", "sys.argv[2]", "sys.argv[3]", "sys.argv[4]" +) + def test_new_simple(tmp_dir, scm, dvc, mocker): tmp_dir.gen("copy.py", COPY_SCRIPT) @@ -236,3 +275,77 @@ def test_detached_parent(tmp_dir, scm, dvc, mocker): assert dvc.experiments.get_baseline(exp_rev) == detached_rev assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3" assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3" + + +def test_new_checkpoint(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + outs_persist=["foo"], + always_changed=True, + name="checkpoint-file", + ) + scm.add( + [ + "dvc.yaml", + "dvc.lock", + "checkpoint.py", + "params.yaml", + "metrics.yaml", + ] + ) + scm.commit("init") + + new_mock = mocker.spy(dvc.experiments, "new") + dvc.reproduce( + stage.addressing, experiment=True, checkpoint=True, params=["foo=2"] + ) + + new_mock.assert_called_once() + assert (tmp_dir / "foo").read_text() == "5" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2" + + +def test_continue_checkpoint(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + outs_persist=["foo"], + always_changed=True, + name="checkpoint-file", + ) + scm.add( + [ + "dvc.yaml", + "dvc.lock", + "checkpoint.py", + "params.yaml", + "metrics.yaml", + ] + ) + scm.commit("init") + + results = dvc.reproduce( + stage.addressing, experiment=True, checkpoint=True, params=["foo=2"] + ) + exp_rev = first(results) + + dvc.reproduce( + stage.addressing, + experiment=True, + checkpoint=True, + checkpoint_continue=exp_rev, + ) + + assert (tmp_dir / "foo").read_text() == "10" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2" diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index ed65309ad0..7b2213e3da 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1302,4 +1302,4 @@ def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy, mocker): stage.addressing: ["changed checksum"] } assert dvc.reproduce(stage.addressing)[0] == stage - m.assert_called_once_with(stage) + m.assert_called_once_with(stage, checkpoint=False) diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index 4916d66c3f..d1bb897fbe 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -286,7 +286,7 @@ def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy, mocker): assert dvc.status([target]) == {target: ["changed command"]} assert dvc.reproduce(target)[0] == stage - m.assert_called_once_with(stage) + m.assert_called_once_with(stage, checkpoint=False) def test_repro_when_new_deps_is_added_in_dvcfile(tmp_dir, dvc, run_copy): diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py index e3914aa6db..413fcf24c7 100644 --- a/tests/func/test_stage.py +++ b/tests/func/test_stage.py @@ -9,6 +9,7 @@ from dvc.repo import Repo from dvc.stage import PipelineStage, Stage from dvc.stage.exceptions import StageFileFormatError +from dvc.stage.run import run_stage from dvc.tree.local import LocalTree from dvc.utils.serialize import dump_yaml, load_yaml from tests.basic_env import TestDvc @@ -274,3 +275,17 @@ def test_stage_remove_pointer_stage(tmp_dir, dvc, run_copy): with dvc.lock: stage.remove() assert not (tmp_dir / stage.relpath).exists() + + +@pytest.mark.parametrize("checkpoint", [True, False]) +def test_stage_run_checkpoint(tmp_dir, dvc, mocker, checkpoint): + stage = Stage(dvc, "stage.dvc", cmd="mycmd arg1 arg2") + mocker.patch.object(stage, "save") + + mock_cmd_run = mocker.patch("dvc.stage.run.cmd_run") + if checkpoint: + callback = mocker.Mock() + else: + callback = None + run_stage(stage, checkpoint_func=callback) + mock_cmd_run.assert_called_with(stage, checkpoint=checkpoint) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 0d751f00c0..dd4f88fca3 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -1,5 +1,10 @@ from dvc.cli import parse_args -from dvc.command.experiments import CmdExperimentsDiff, CmdExperimentsShow +from dvc.command.experiments import ( + CmdExperimentsDiff, + CmdExperimentsRun, + CmdExperimentsShow, +) +from dvc.dvcfile import PIPELINE_FILE def test_experiments_diff(dvc, mocker): @@ -54,3 +59,34 @@ def test_experiments_show(dvc, mocker): all_commits=True, sha_only=True, ) + + +default_run_arguments = { + "all_pipelines": False, + "downstream": False, + "dry": False, + "force": False, + "run_cache": True, + "interactive": False, + "no_commit": False, + "pipeline": False, + "single_item": False, + "recursive": False, + "force_downstream": False, + "params": [], + "queue": False, + "run_all": False, + "jobs": None, + "checkpoint": False, + "checkpoint_continue": None, + "experiment": True, +} + + +def test_experiments_run(dvc, mocker): + cmd = CmdExperimentsRun(parse_args(["exp", "run"])) + mocker.patch.object(cmd.repo, "reproduce") + cmd.run() + cmd.repo.reproduce.assert_called_with( + PIPELINE_FILE, **default_run_arguments + )