diff --git a/dvc/command/base.py b/dvc/command/base.py
index 9391418f9e..78eb1a0136 100644
--- a/dvc/command/base.py
+++ b/dvc/command/base.py
@@ -57,3 +57,13 @@ def __init__(self, args):  # pylint: disable=super-init-not-called
 
     def do_run(self):
         return self.run()
+
+
+def fix_plumbing_subparsers(subparsers):
+    # metavar needs to be explicitly set in order to hide plumbing subcommands
+    # from the 'positional arguments' choices list
+    # see: https://bugs.python.org/issue22848
+    cmds = [
+        cmd for cmd, parser in subparsers.choices.items() if parser.add_help
+    ]
+    subparsers.metavar = "{{{}}}".format(",".join(cmds))
diff --git a/dvc/command/experiments/__init__.py b/dvc/command/experiments/__init__.py
index 93b7c44aa4..d8ff7e825c 100644
--- a/dvc/command/experiments/__init__.py
+++ b/dvc/command/experiments/__init__.py
@@ -1,10 +1,15 @@
 import argparse
 
-from dvc.command.base import append_doc_link, fix_subparsers
+from dvc.command.base import (
+    append_doc_link,
+    fix_plumbing_subparsers,
+    fix_subparsers,
+)
 from dvc.command.experiments import (
     apply,
     branch,
     diff,
+    exec_run,
     gc,
     init,
     ls,
@@ -19,6 +24,7 @@
     apply,
     branch,
     diff,
+    exec_run,
     gc,
     init,
     ls,
@@ -51,3 +57,4 @@ def add_parser(subparsers, parent_parser):
     fix_subparsers(experiments_subparsers)
     for cmd in SUB_COMMANDS:
         cmd.add_parser(experiments_subparsers, parent_parser)
+    fix_plumbing_subparsers(experiments_subparsers)
diff --git a/dvc/command/experiments/exec_run.py b/dvc/command/experiments/exec_run.py
new file mode 100644
index 0000000000..f49b64e510
--- /dev/null
+++ b/dvc/command/experiments/exec_run.py
@@ -0,0 +1,42 @@
+import logging
+
+from dvc.command.base import CmdBaseNoRepo
+
+logger = logging.getLogger(__name__)
+
+
+class CmdExecutorRun(CmdBaseNoRepo):
+    """Run an experiment executor."""
+
+    def run(self):
+        from dvc.repo.experiments.executor.base import (
+            BaseExecutor,
+            ExecutorInfo,
+        )
+        from dvc.utils.serialize import load_json
+
+        info = ExecutorInfo.from_dict(load_json(self.args.infofile))
+        BaseExecutor.reproduce(
+            info=info,
+            rev="",
+            queue=None,
+            log_level=logger.getEffectiveLevel(),
+            infofile=self.args.infofile,
+        )
+        return 0
+
+
+def add_parser(experiments_subparsers, parent_parser):
+    EXEC_RUN_HELP = "Run an experiment executor."
+    exec_run_parser = experiments_subparsers.add_parser(
+        "exec-run",
+        parents=[parent_parser],
+        description=EXEC_RUN_HELP,
+        add_help=False,
+    )
+    exec_run_parser.add_argument(
+        "--infofile",
+        help="Path to executor info file",
+        default=None,
+    )
+    exec_run_parser.set_defaults(func=CmdExecutorRun)
diff --git a/dvc/command/experiments/run.py b/dvc/command/experiments/run.py
index 3168ff2b20..d2381463b8 100644
--- a/dvc/command/experiments/run.py
+++ b/dvc/command/experiments/run.py
@@ -38,6 +38,7 @@ def run(self):
             checkpoint_resume=self.args.checkpoint_resume,
             reset=self.args.reset,
             tmp_dir=self.args.tmp_dir,
+            machine=self.args.machine,
             **self._repro_kwargs,
         )
 
@@ -130,3 +131,12 @@ def _add_run_common(parser):
             "your workspace."
         ),
     )
+    parser.add_argument(
+        "--machine",
+        default=None,
+        help=argparse.SUPPRESS,
+        # help=(
+        #     "Run this experiment on the specified 'dvc machine' instance."
+        # )
+        # metavar="",
+    )
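Note: `exec-run` above is a plumbing subcommand; it is registered with `add_help=False` and then hidden from the `dvc exp --help` choices list by `fix_plumbing_subparsers`. A standalone sketch of that same `argparse` workaround for bpo-22848, with hypothetical command names:

```python
# Minimal model of fix_plumbing_subparsers(): hide subcommands that were
# registered with add_help=False from the usage/choices listing.
import argparse

parser = argparse.ArgumentParser(prog="demo")
subparsers = parser.add_subparsers(dest="cmd")
subparsers.add_parser("visible")  # porcelain: add_help defaults to True
subparsers.add_parser("hidden", add_help=False)  # plumbing

# Without an explicit metavar, 'demo --help' would list {visible,hidden}.
cmds = [cmd for cmd, sub in subparsers.choices.items() if sub.add_help]
subparsers.metavar = "{{{}}}".format(",".join(cmds))

args = parser.parse_args(["hidden"])  # hidden commands still parse fine
assert args.cmd == "hidden"
```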
diff --git a/dvc/config.py b/dvc/config.py
index ad8a98b32b..caa62e785f 100644
--- a/dvc/config.py
+++ b/dvc/config.py
@@ -266,7 +266,12 @@ def _map_dirs(conf, func):
                 "key_path": func,
             }
         },
-        "machine": {str: {"startup_script": func}},
+        "machine": {
+            str: {
+                "startup_script": func,
+                "setup_script": func,
+            }
+        },
     }
     return Schema(dirs_schema, extra=ALLOW_EXTRA)(conf)
 
diff --git a/dvc/config_schema.py b/dvc/config_schema.py
index 098115f24a..aa495a890a 100644
--- a/dvc/config_schema.py
+++ b/dvc/config_schema.py
@@ -257,6 +257,7 @@ class RelPath(str):
             "instance_gpu": Lower,
             "ssh_private": str,
             "startup_script": str,
+            "setup_script": str,
         },
     },
     # section for experimental features
diff --git a/dvc/machine/__init__.py b/dvc/machine/__init__.py
index ccf767f461..d93c3e7b99 100644
--- a/dvc/machine/__init__.py
+++ b/dvc/machine/__init__.py
@@ -37,6 +37,8 @@
 sudo apt-get update
 sudo apt-get install --yes dvc
 popd
+
+echo "OK" | sudo tee /var/log/dvc-machine-init.log
 """
 
@@ -188,6 +190,7 @@ def create(self, name: Optional[str]):
         else:
             startup_script = DEFAULT_STARTUP_SCRIPT
         config["startup_script"] = startup_script
+        config.pop("setup_script", None)
         return backend.create(**config)
 
     def destroy(self, name: Optional[str]):
@@ -215,3 +218,7 @@ def rename(self, name: str, new: str):
     def get_executor_kwargs(self, name: Optional[str]):
         config, backend = self.get_config_and_backend(name)
         return backend.get_executor_kwargs(**config)
+
+    def get_setup_script(self, name: Optional[str]) -> Optional[str]:
+        config, _backend = self.get_config_and_backend(name)
+        return config.get("setup_script")
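Note: with the schema change above, a machine section may now carry two scripts: `startup_script` runs once when the instance is created (and is the only one passed to `backend.create`, since `setup_script` is popped from the config first), while `setup_script` is sourced on the remote machine before each experiment run (see `dvc/repo/experiments/executor/ssh.py` below). A sketch of how the extended section validates, assuming only that `voluptuous` is installed; the machine name and paths are illustrative:

```python
# Mirrors the "machine" entry added to dvc/config_schema.py above.
from voluptuous import ALLOW_EXTRA, Schema

machine_schema = Schema(
    {str: {"startup_script": str, "setup_script": str}},
    extra=ALLOW_EXTRA,
)
conf = machine_schema(
    {"aws-gpu": {"startup_script": "boot.sh", "setup_script": "setup.sh"}}
)
assert conf["aws-gpu"]["setup_script"] == "setup.sh"
```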
""" - return manager.exec_queue(**kwargs) + return manager.exec_queue(self.repo, **kwargs) def check_baseline(self, exp_rev): baseline_sha = self.repo.scm.get_rev() @@ -740,6 +747,8 @@ def get_running_exps(self) -> Dict[str, int]: from dvc.scm import InvalidRemoteSCMRepo from dvc.utils.serialize import load_json + from .executor.local import TempDirExecutor + result = {} pid_dir = os.path.join( self.repo.tmp_dir, @@ -770,10 +779,10 @@ def get_running_exps(self) -> Dict[str, int]: def on_diverged(_ref: str, _checkpoint: bool): return False + executor = TempDirExecutor.from_info(info) try: - for ref in BaseExecutor.fetch_exps( + for ref in executor.fetch_exps( self.scm, - info.git_url, on_diverged=on_diverged, ): logger.debug( diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 9943a9417b..34fdb2e878 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -13,6 +13,7 @@ Iterable, NamedTuple, Optional, + Tuple, Type, TypeVar, Union, @@ -90,6 +91,14 @@ def result(self) -> Optional["ExecutorResult"]: self.result_force, ) + def dump_json(self, filename: str): + from dvc.utils.fs import makedirs + from dvc.utils.serialize import modify_json + + makedirs(os.path.dirname(filename), exist_ok=True) + with modify_json(filename) as d: + d.update(self.asdict()) + _T = TypeVar("_T", bound="BaseExecutor") @@ -123,7 +132,7 @@ def __init__( result: Optional["ExecutorResult"] = None, **kwargs, ): - self._dvc_dir = dvc_dir + self.dvc_dir = dvc_dir self.root_dir = root_dir self.wdir = wdir self.name = name @@ -143,8 +152,14 @@ def git_url(self) -> str: pass @abstractmethod - def init_cache(self, dvc: "Repo", rev: str, run_cache: bool = True): - """Initialize DVC (cache).""" + def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True): + """Initialize DVC cache.""" + + @abstractmethod + def collect_cache( + self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True + ): + """Collect DVC cache.""" @property def info(self) -> "ExecutorInfo": @@ -224,10 +239,6 @@ def _from_stash_entry( executor.init_cache(repo, stash_rev) return executor - @property - def dvc_dir(self) -> str: - return os.path.join(self.root_dir, self._dvc_dir) - @staticmethod def hash_exp(stages: Iterable["PipelineStage"]) -> str: from dvc.stage import PipelineStage @@ -266,28 +277,33 @@ def unpack_repro_args(path): data = pickle.load(fobj) return data["args"], data["kwargs"] - @classmethod def fetch_exps( - cls, + self, dest_scm: "Git", - url: str, force: bool = False, on_diverged: Callable[[str, bool], None] = None, + **kwargs, ) -> Iterable[str]: - """Fetch reproduced experiments into the specified SCM. + """Fetch reproduced experiment refs into the specified SCM. Args: dest_scm: Destination Git instance. - url: Git remote URL to fetch from. force: If True, diverged refs will be overwritten on_diverged: Callback in the form on_diverged(ref, is_checkpoint) to be called when an experiment ref has diverged. + + Extra kwargs will be passed into the remote git client. 
""" from ..utils import iter_remote_refs refs = [] has_checkpoint = False - for ref in iter_remote_refs(dest_scm, url, base=EXPS_NAMESPACE): + for ref in iter_remote_refs( + dest_scm, + self.git_url, + base=EXPS_NAMESPACE, + **kwargs, + ): if ref == EXEC_CHECKPOINT: has_checkpoint = True elif not ref.startswith(EXEC_NAMESPACE) and ref != EXPS_STASH: @@ -298,7 +314,7 @@ def on_diverged_ref(orig_ref: str, new_rev: str): logger.debug("Replacing existing experiment '%s'", orig_ref) return True - cls._raise_ref_conflict( + self._raise_ref_conflict( dest_scm, orig_ref, new_rev, has_checkpoint ) if on_diverged: @@ -308,17 +324,19 @@ def on_diverged_ref(orig_ref: str, new_rev: str): # fetch experiments dest_scm.fetch_refspecs( - url, + self.git_url, [f"{ref}:{ref}" for ref in refs], on_diverged=on_diverged_ref, force=force, + **kwargs, ) # update last run checkpoint (if it exists) if has_checkpoint: dest_scm.fetch_refspecs( - url, + self.git_url, [f"{EXEC_CHECKPOINT}:{EXEC_CHECKPOINT}"], force=True, + **kwargs, ) return refs @@ -382,10 +400,12 @@ def filter_pipeline(stages): exp_ref: Optional["ExpRefInfo"] = None repro_force: bool = False + if infofile is not None: + info.dump_json(infofile) + with cls._repro_dvc( info, log_errors=log_errors, - infofile=infofile, **kwargs, ) as dvc: if auto_push: @@ -446,61 +466,84 @@ def filter_pipeline(stages): exp_hash = cls.hash_exp(stages) if not repro_dry: - try: - is_checkpoint = any( - stage.is_checkpoint for stage in stages - ) - if is_checkpoint and checkpoint_reset: - # For reset checkpoint stages, we need to force - # overwriting existing checkpoint refs even though - # repro may not have actually been run with --force - repro_force = True - cls.commit( - dvc.scm, - exp_hash, - exp_name=info.name, - force=repro_force, - checkpoint=is_checkpoint, - ) - if auto_push: - cls._auto_push(dvc, dvc.scm, git_remote) - except UnchangedExperimentError: - pass - ref = dvc.scm.get_ref(EXEC_BRANCH, follow=False) - if ref: - exp_ref = ExpRefInfo.from_ref(ref) - if cls.WARN_UNTRACKED: - untracked = dvc.scm.untracked_files() - if untracked: - logger.warning( - "The following untracked files were present in " - "the experiment directory after reproduction but " - "will not be included in experiment commits:\n" - "\t%s", - ", ".join(untracked), - ) + ref, exp_ref, repro_force = cls._repro_commit( + dvc, + info, + stages, + exp_hash, + checkpoint_reset, + auto_push, + git_remote, + repro_force, + ) info.result_hash = exp_hash info.result_ref = ref info.result_force = repro_force + if infofile is not None: + info.dump_json(infofile) + # ideally we would return stages here like a normal repro() call, but # stages is not currently picklable and cannot be returned across # multiprocessing calls return ExecutorResult(exp_hash, exp_ref, repro_force) + @classmethod + def _repro_commit( + cls, + dvc, + info, + stages, + exp_hash, + checkpoint_reset, + auto_push, + git_remote, + repro_force, + ) -> Tuple[Optional[str], Optional["ExpRefInfo"], bool]: + try: + is_checkpoint = any(stage.is_checkpoint for stage in stages) + if is_checkpoint and checkpoint_reset: + # For reset checkpoint stages, we need to force + # overwriting existing checkpoint refs even though + # repro may not have actually been run with --force + repro_force = True + cls.commit( + dvc.scm, + exp_hash, + exp_name=info.name, + force=repro_force, + checkpoint=is_checkpoint, + ) + if auto_push: + cls._auto_push(dvc, dvc.scm, git_remote) + except UnchangedExperimentError: + pass + ref: Optional[str] = 
dvc.scm.get_ref(EXEC_BRANCH, follow=False) + exp_ref: Optional["ExpRefInfo"] = ( + ExpRefInfo.from_ref(ref) if ref else None + ) + if cls.WARN_UNTRACKED: + untracked = dvc.scm.untracked_files() + if untracked: + logger.warning( + "The following untracked files were present in " + "the experiment directory after reproduction but " + "will not be included in experiment commits:\n" + "\t%s", + ", ".join(untracked), + ) + return ref, exp_ref, repro_force + @classmethod @contextmanager def _repro_dvc( cls, info: "ExecutorInfo", log_errors: bool = True, - infofile: Optional[str] = None, **kwargs, ): from dvc.repo import Repo from dvc.stage.monitor import CheckpointKilledError - from dvc.utils.fs import makedirs - from dvc.utils.serialize import modify_json dvc = Repo(os.path.join(info.root_dir, info.dvc_dir)) if cls.QUIET: @@ -511,11 +554,6 @@ def _repro_dvc( else: os.chdir(dvc.root_dir) - if infofile is not None: - makedirs(os.path.dirname(infofile), exist_ok=True) - with modify_json(infofile) as d: - d.update(info.asdict()) - try: logger.debug("Running repro in '%s'", os.getcwd()) yield dvc @@ -530,9 +568,6 @@ def _repro_dvc( logger.exception("unexpected error") raise finally: - if infofile is not None: - with modify_json(infofile) as d: - d.update(info.asdict()) dvc.close() os.chdir(old_cwd) diff --git a/dvc/repo/experiments/executor/local.py b/dvc/repo/experiments/executor/local.py index d8a9f68675..fc74942af3 100644 --- a/dvc/repo/experiments/executor/local.py +++ b/dvc/repo/experiments/executor/local.py @@ -25,7 +25,7 @@ from dvc.repo import Repo - from ..base import ExpStashEntry + from ..base import ExpRefInfo, ExpStashEntry logger = logging.getLogger(__name__) @@ -49,6 +49,11 @@ def cleanup(self): self.scm.close() del self.scm + def collect_cache( + self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True + ): + """Collect DVC cache.""" + class TempDirExecutor(BaseLocalExecutor): """Temp directory experiment executor.""" @@ -86,14 +91,18 @@ def init_git(self, scm: "Git", branch: Optional[str] = None): self.scm.merge(merge_rev, squash=True, commit=False) def _config(self, cache_dir): - local_config = os.path.join(self.dvc_dir, "config.local") + local_config = os.path.join( + self.root_dir, + self.dvc_dir, + "config.local", + ) logger.debug("Writing experiments local config '%s'", local_config) with open(local_config, "w", encoding="utf-8") as fobj: fobj.write(f"[cache]\n dir = {cache_dir}") - def init_cache(self, dvc: "Repo", rev: str, run_cache: bool = True): - """Initialize DVC (cache).""" - self._config(dvc.odb.local.cache_dir) + def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True): + """Initialize DVC cache.""" + self._config(repo.odb.local.cache_dir) def cleanup(self): super().cleanup() @@ -152,7 +161,7 @@ def init_git(self, scm: "Git", branch: Optional[str] = None): elif scm.get_ref(EXEC_BRANCH): self.scm.remove_ref(EXEC_BRANCH) - def init_cache(self, dvc: "Repo", rev: str, run_cache: bool = True): + def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True): pass def cleanup(self): diff --git a/dvc/repo/experiments/executor/manager/__init__.py b/dvc/repo/experiments/executor/manager/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dvc/repo/experiments/executor/manager.py b/dvc/repo/experiments/executor/manager/base.py similarity index 70% rename from dvc/repo/experiments/executor/manager.py rename to dvc/repo/experiments/executor/manager/base.py index 02c182921a..2f53802fe6 100644 --- 
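Note: `ExecutorInfo.dump_json` above centralizes the infofile writes that `_repro_dvc` previously did inline: the file is written once before reproduction starts and rewritten with the result fields afterwards, so an external process (such as a `dvc exp exec-run` run polled over SSH) can observe progress. A simplified stdlib-only sketch of that round trip (the real helper merges into any existing document via `modify_json`; this one just rewrites, and the field names are an abbreviated, illustrative subset of `ExecutorInfo`):

```python
import json
import os
import tempfile

def dump_json(data: dict, filename: str) -> None:
    # mkdir -p the parent directory, then (re)write the JSON document
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w", encoding="utf-8") as fobj:
        json.dump(data, fobj)

infofile = os.path.join(tempfile.mkdtemp(), "tmp", "workspace.run")
info = {"git_url": "", "result_hash": None}
dump_json(info, infofile)            # written before repro starts
info["result_hash"] = "abc123"
dump_json(info, infofile)            # rewritten once a result exists
with open(infofile, encoding="utf-8") as fobj:
    assert json.load(fobj)["result_hash"] == "abc123"
```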
diff --git a/dvc/repo/experiments/executor/manager.py b/dvc/repo/experiments/executor/manager/base.py
similarity index 70%
rename from dvc/repo/experiments/executor/manager.py
rename to dvc/repo/experiments/executor/manager/base.py
index 02c182921a..2f53802fe6 100644
--- a/dvc/repo/experiments/executor/manager.py
+++ b/dvc/repo/experiments/executor/manager/base.py
@@ -7,9 +7,8 @@
 
 from dvc.proc.manager import ProcessManager
 
-from ..base import (
+from ...base import (
     EXEC_BASELINE,
-    EXEC_BRANCH,
     EXEC_HEAD,
     EXEC_MERGE,
     CheckpointExistsError,
@@ -17,8 +16,8 @@
     ExpRefInfo,
     ExpStashEntry,
 )
-from .base import EXEC_PID_DIR, BaseExecutor
-from .local import TempDirExecutor, WorkspaceExecutor
+from ..base import EXEC_PID_DIR, BaseExecutor
+from ..local import TempDirExecutor, WorkspaceExecutor
 
 if TYPE_CHECKING:
     from scmrepo.git import Git
@@ -37,6 +36,7 @@ def __init__(
         self,
         scm: "Git",
         wdir: str,
+        **kwargs,
     ):
         from dvc.utils.fs import makedirs
@@ -74,8 +74,8 @@ def _load_infos(self) -> Generator[Tuple[str, "BaseExecutor"], None, None]:
         import json
         from urllib.parse import urlparse
 
-        from .base import ExecutorInfo
-        from .ssh import SSHExecutor
+        from ..base import ExecutorInfo
+        from ..ssh import SSHExecutor
 
         def make_executor(info: "ExecutorInfo"):
             if info.git_url:
@@ -116,6 +116,16 @@ def from_stash_entries(
         **kwargs,
     ):
         manager = cls(scm, wdir)
+        manager._enqueue_stash_entries(scm, repo, to_run, **kwargs)
+        return manager
+
+    def _enqueue_stash_entries(
+        self,
+        scm: "Git",
+        repo: "Repo",
+        to_run: Dict[str, ExpStashEntry],
+        **kwargs,
+    ):
         try:
             for stash_rev, entry in to_run.items():
                 scm.set_ref(EXEC_HEAD, entry.head_rev)
@@ -128,26 +138,27 @@ def from_stash_entries(
                 # EXEC_MERGE - the unmerged changes (from our stash)
                 #   to be reproduced
                 # EXEC_BASELINE - the baseline commit for this experiment
-                executor = cls.EXECUTOR_CLS.from_stash_entry(
+                executor = self.EXECUTOR_CLS.from_stash_entry(
                     repo,
                     stash_rev,
                     entry,
                     **kwargs,
                 )
-                manager.enqueue(stash_rev, executor)
+                self.enqueue(stash_rev, executor)
         finally:
             for ref in (EXEC_HEAD, EXEC_MERGE, EXEC_BASELINE):
                 scm.remove_ref(ref)
-        return manager
 
-    def exec_queue(self, jobs: Optional[int] = 1, detach: bool = False):
+    def exec_queue(
+        self, repo: "Repo", jobs: Optional[int] = 1, detach: bool = False
+    ):
         """Run dvc repro for queued executors in parallel."""
         if detach:
             raise NotImplementedError
             # TODO use ProcessManager.spawn() to support detached runs
-        return self._exec_attached(jobs=jobs)
+        return self._exec_attached(repo, jobs=jobs)
 
-    def _exec_attached(self, jobs: Optional[int] = 1):
+    def _exec_attached(self, repo: "Repo", jobs: Optional[int] = 1):
         import signal
         from concurrent.futures import (
             CancelledError,
@@ -205,7 +216,7 @@ def _exec_attached(self, jobs: Optional[int] = 1):
                 if exc is None:
                     exec_result = future.result()
                     result[rev].update(
-                        self._collect_executor(executor, exec_result)
+                        self._collect_executor(repo, executor, exec_result)
                     )
                 elif not isinstance(exc, CheckpointKilledError):
                     logger.error(
@@ -222,7 +233,7 @@ def _exec_attached(self, jobs: Optional[int] = 1):
 
         return result
 
-    def _collect_executor(self, executor, exec_result) -> Dict[str, str]:
+    def _collect_executor(self, repo, executor, exec_result) -> Dict[str, str]:
         # NOTE: GitPython Repo instances cannot be re-used
         # after process has received SIGINT or SIGTERM, so we
         # need this hack to re-instantiate git instances after
@@ -240,7 +251,6 @@ def on_diverged(ref: str, checkpoint: bool):
 
         for ref in executor.fetch_exps(
             self.scm,
-            executor.git_url,
             force=exec_result.force,
             on_diverged=on_diverged,
         ):
@@ -249,6 +259,9 @@ def on_diverged(ref: str, checkpoint: bool):
                 logger.debug("Collected experiment '%s'.", exp_rev[:7])
                 results[exp_rev] = exec_result.exp_hash
 
+        if exec_result.ref_info is not None:
+            executor.collect_cache(repo, exec_result.ref_info)
+
         return results
 
     def cleanup_executor(self, rev: str, executor: "BaseExecutor"):
@@ -260,92 +273,3 @@ def cleanup_executor(self, rev: str, executor: "BaseExecutor"):
             except KeyError:
                 pass
         remove(os.path.join(self.pid_dir, rev))
-
-
-class TempDirExecutorManager(BaseExecutorManager):
-    EXECUTOR_CLS = TempDirExecutor
-
-
-class WorkspaceExecutorManager(BaseExecutorManager):
-    EXECUTOR_CLS = WorkspaceExecutor
-
-    @classmethod
-    def from_stash_entries(
-        cls,
-        scm: "Git",
-        wdir: str,
-        repo: "Repo",
-        to_run: Dict[str, ExpStashEntry],
-        **kwargs,
-    ):
-        manager = cls(scm, wdir)
-        try:
-            assert len(to_run) == 1
-            for stash_rev, entry in to_run.items():
-                scm.set_ref(EXEC_HEAD, entry.head_rev)
-                scm.set_ref(EXEC_MERGE, stash_rev)
-                scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
-
-                executor = cls.EXECUTOR_CLS.from_stash_entry(
-                    repo,
-                    stash_rev,
-                    entry,
-                    **kwargs,
-                )
-                manager.enqueue(stash_rev, executor)
-        finally:
-            for ref in (EXEC_MERGE,):
-                scm.remove_ref(ref)
-        return manager
-
-    def _collect_executor(self, executor, exec_result) -> Dict[str, str]:
-        results = {}
-        exp_rev = self.scm.get_ref(EXEC_BRANCH)
-        if exp_rev:
-            logger.debug("Collected experiment '%s'.", exp_rev[:7])
-            results[exp_rev] = exec_result.exp_hash
-        return results
-
-    def exec_queue(self, jobs: Optional[int] = 1, detach: bool = False):
-        """Run a single WorkspaceExecutor.
-
-        Workspace execution is done within the main DVC process
-        (rather than in multiprocessing context)
-        """
-        from dvc.exceptions import DvcException
-        from dvc.stage.monitor import CheckpointKilledError
-
-        assert len(self._queue) == 1
-        assert not detach
-        result: Dict[str, Dict[str, str]] = defaultdict(dict)
-        rev, executor = self._queue.popleft()
-
-        exec_name = "workspace"
-        infofile = self.get_infofile_path(exec_name)
-        try:
-            exec_result = executor.reproduce(
-                info=executor.info,
-                rev=rev,
-                infofile=infofile,
-                log_level=logger.getEffectiveLevel(),
-            )
-            if not exec_result.exp_hash:
-                raise DvcException(
-                    f"Failed to reproduce experiment '{rev[:7]}'"
-                )
-            if exec_result.ref_info:
-                result[rev].update(
-                    self._collect_executor(executor, exec_result)
-                )
-        except CheckpointKilledError:
-            # Checkpoint errors have already been logged
-            return {}
-        except DvcException:
-            raise
-        except Exception as exc:
-            raise DvcException(
-                f"Failed to reproduce experiment '{rev[:7]}'"
-            ) from exc
-        finally:
-            self.cleanup_executor(exec_name, executor)
-        return result
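Note: the enqueue loop now shared through `_enqueue_stash_entries` pins three temporary refs per stash entry and always clears them, even on failure. A minimal model of that set-then-cleanup shape using a context manager (the ref store here is just a dict standing in for the Git ref namespace):

```python
from contextlib import contextmanager

TEMP_REFS = ("EXEC_HEAD", "EXEC_MERGE", "EXEC_BASELINE")

@contextmanager
def temporary_refs(store: dict, head: str, stash: str, baseline: str):
    store.update(EXEC_HEAD=head, EXEC_MERGE=stash, EXEC_BASELINE=baseline)
    try:
        yield store  # executors are created/enqueued while refs are set
    finally:
        for name in TEMP_REFS:
            store.pop(name, None)  # mirrors scm.remove_ref() in the finally

refs: dict = {}
with temporary_refs(refs, "head-rev", "stash-rev", "baseline-rev"):
    assert refs["EXEC_MERGE"] == "stash-rev"
assert not refs
```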
diff --git a/dvc/repo/experiments/executor/manager/local.py b/dvc/repo/experiments/executor/manager/local.py
new file mode 100644
index 0000000000..da139f759c
--- /dev/null
+++ b/dvc/repo/experiments/executor/manager/local.py
@@ -0,0 +1,111 @@
+import logging
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Optional
+
+from ...base import (
+    EXEC_BASELINE,
+    EXEC_BRANCH,
+    EXEC_HEAD,
+    EXEC_MERGE,
+    ExpStashEntry,
+)
+from ..local import TempDirExecutor, WorkspaceExecutor
+from .base import BaseExecutorManager
+
+if TYPE_CHECKING:
+    from scmrepo.git import Git
+
+    from dvc.repo import Repo
+
+logger = logging.getLogger(__name__)
+
+
+class TempDirExecutorManager(BaseExecutorManager):
+    EXECUTOR_CLS = TempDirExecutor
+
+
+class WorkspaceExecutorManager(BaseExecutorManager):
+    EXECUTOR_CLS = WorkspaceExecutor
+
+    @classmethod
+    def from_stash_entries(
+        cls,
+        scm: "Git",
+        wdir: str,
+        repo: "Repo",
+        to_run: Dict[str, ExpStashEntry],
+        **kwargs,
+    ):
+        manager = cls(scm, wdir)
+        try:
+            assert len(to_run) == 1
+            for stash_rev, entry in to_run.items():
+                scm.set_ref(EXEC_HEAD, entry.head_rev)
+                scm.set_ref(EXEC_MERGE, stash_rev)
+                scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
+
+                executor = cls.EXECUTOR_CLS.from_stash_entry(
+                    repo,
+                    stash_rev,
+                    entry,
+                    **kwargs,
+                )
+                manager.enqueue(stash_rev, executor)
+        finally:
+            for ref in (EXEC_MERGE,):
+                scm.remove_ref(ref)
+        return manager
+
+    def _collect_executor(self, repo, executor, exec_result) -> Dict[str, str]:
+        results = {}
+        exp_rev = self.scm.get_ref(EXEC_BRANCH)
+        if exp_rev:
+            logger.debug("Collected experiment '%s'.", exp_rev[:7])
+            results[exp_rev] = exec_result.exp_hash
+        return results
+
+    def exec_queue(
+        self, repo: "Repo", jobs: Optional[int] = 1, detach: bool = False
+    ):
+        """Run a single WorkspaceExecutor.
+
+        Workspace execution is done within the main DVC process
+        (rather than in multiprocessing context)
+        """
+        from dvc.exceptions import DvcException
+        from dvc.stage.monitor import CheckpointKilledError
+
+        assert len(self._queue) == 1
+        assert not detach
+        result: Dict[str, Dict[str, str]] = defaultdict(dict)
+        rev, executor = self._queue.popleft()
+
+        exec_name = "workspace"
+        infofile = self.get_infofile_path(exec_name)
+        try:
+            exec_result = executor.reproduce(
+                info=executor.info,
+                rev=rev,
+                infofile=infofile,
+                log_level=logger.getEffectiveLevel(),
+            )
+            if not exec_result.exp_hash:
+                raise DvcException(
+                    f"Failed to reproduce experiment '{rev[:7]}'"
+                )
+            if exec_result.ref_info:
+                result[rev].update(
+                    self._collect_executor(repo, executor, exec_result)
+                )
+        except CheckpointKilledError:
+            # Checkpoint errors have already been logged
+            return {}
+        except DvcException:
+            raise
+        except Exception as exc:
+            raise DvcException(
+                f"Failed to reproduce experiment '{rev[:7]}'"
+            ) from exc
+        finally:
+            self.cleanup_executor(exec_name, executor)
+        return result
diff --git a/dvc/repo/experiments/executor/manager/ssh.py b/dvc/repo/experiments/executor/manager/ssh.py
new file mode 100644
index 0000000000..a67c5c7e17
--- /dev/null
+++ b/dvc/repo/experiments/executor/manager/ssh.py
@@ -0,0 +1,109 @@
+import logging
+import posixpath
+from collections import defaultdict
+from typing import TYPE_CHECKING, Callable, Dict, Generator, Optional, Tuple
+
+from ...base import ExpStashEntry
+from ..base import BaseExecutor
+from ..ssh import SSHExecutor, _sshfs
+from .base import BaseExecutorManager
+
+if TYPE_CHECKING:
+    from scmrepo.git import Git
+
+    from dvc.repo import Repo
+
+logger = logging.getLogger(__name__)
+
+
+class SSHExecutorManager(BaseExecutorManager):
+    EXECUTOR_CLS = SSHExecutor
+
+    def __init__(
+        self,
+        scm: "Git",
+        wdir: str,
+        host: Optional[str] = None,
+        port: Optional[int] = None,
+        username: Optional[str] = None,
+        fs_factory: Optional[Callable] = None,
+        **kwargs,
+    ):
+        assert host
+        super().__init__(scm, wdir, **kwargs)
+        self.host = host
+        self.port = port
+        self.username = username
+        self._fs_factory = fs_factory
+
+    def _load_infos(self) -> Generator[Tuple[str, "BaseExecutor"], None, None]:
+        # TODO: load existing infos using sshfs
+        yield from []
+
+    @classmethod
+    def from_stash_entries(
+        cls,
+        scm: "Git",
+        wdir: str,
+        repo: "Repo",
+        to_run: Dict[str, ExpStashEntry],
+        **kwargs,
+    ):
+        machine_name: Optional[str] = kwargs.get("machine_name", None)
+        manager = cls(
+            scm, wdir, **repo.machine.get_executor_kwargs(machine_name)
+        )
+        manager._enqueue_stash_entries(scm, repo, to_run, **kwargs)
+        return manager
+
+    def sshfs(self):
+        return _sshfs(self._fs_factory, host=self.host, port=self.port)
+
+    def get_infofile_path(self, name: str) -> str:
+        return f"{name}{BaseExecutor.INFOFILE_EXT}"
+
+    def _exec_attached(self, repo: "Repo", jobs: Optional[int] = 1):
+        from dvc.exceptions import DvcException
+        from dvc.stage.monitor import CheckpointKilledError
+
+        assert len(self._queue) == 1
+        result: Dict[str, Dict[str, str]] = defaultdict(dict)
+        rev, executor = self._queue.popleft()
+        info = executor.info
+        infofile = posixpath.join(
+            info.root_dir,
+            info.dvc_dir,
+            "tmp",
+            self.get_infofile_path(rev),
+        )
+        try:
+            exec_result = executor.reproduce(
+                info=executor.info,
+                rev=rev,
+                infofile=infofile,
+                log_level=logger.getEffectiveLevel(),
+                fs_factory=self._fs_factory,
+            )
+            if not exec_result.exp_hash:
+                raise DvcException(
+                    f"Failed to reproduce experiment '{rev[:7]}'"
+                )
+            if exec_result.ref_info:
+                result[rev].update(
+                    self._collect_executor(repo, executor, exec_result)
+                )
+        except CheckpointKilledError:
+            # Checkpoint errors have already been logged
+            return {}
+        except DvcException:
+            raise
+        except Exception as exc:
+            raise DvcException(
+                f"Failed to reproduce experiment '{rev[:7]}'"
+            ) from exc
+        finally:
+            self.cleanup_executor(rev, executor)
+        return result
+
+    def cleanup_executor(self, rev: str, executor: "BaseExecutor"):
+        executor.cleanup()
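Note: unlike the local managers, `SSHExecutorManager` assembles the infofile path with `posixpath`, because the path lives on the remote POSIX machine regardless of the local OS. A sketch with illustrative values (`INFOFILE_EXT` here is an assumption mirroring `BaseExecutor.INFOFILE_EXT`):

```python
import posixpath

INFOFILE_EXT = ".run"  # assumption: matches BaseExecutor.INFOFILE_EXT

root_dir = "/home/ubuntu/dvc-exp-a1b2c3"  # hypothetical remote repo dir
rev = "0123abc"
infofile = posixpath.join(root_dir, ".dvc", "tmp", f"{rev}{INFOFILE_EXT}")
assert infofile == "/home/ubuntu/dvc-exp-a1b2c3/.dvc/tmp/0123abc.run"
```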
diff --git a/dvc/repo/experiments/executor/ssh.py b/dvc/repo/experiments/executor/ssh.py
index 3b1df95038..657e05b969 100644
--- a/dvc/repo/experiments/executor/ssh.py
+++ b/dvc/repo/experiments/executor/ssh.py
@@ -1,7 +1,9 @@
 import logging
+import os
 import posixpath
+import sys
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable, Iterable, Optional
 
 from funcy import first
@@ -14,15 +16,16 @@
     EXEC_NAMESPACE,
 )
 
-from .base import BaseExecutor
+from .base import BaseExecutor, ExecutorInfo, ExecutorResult
 
 if TYPE_CHECKING:
+    from multiprocessing import Queue
+
     from scmrepo.git import Git
 
-    from dvc.machine import MachineManager
     from dvc.repo import Repo
 
-    from ..base import ExpStashEntry
+    from ..base import ExpRefInfo, ExpStashEntry
 
 logger = logging.getLogger(__name__)
@@ -41,6 +44,7 @@ class SSHExecutor(BaseExecutor):
 
     WARN_UNTRACKED = True
     QUIET = True
+    SETUP_SCRIPT_FILENAME = "exec-setup.sh"
 
     def __init__(
         self,
@@ -49,6 +53,7 @@ def __init__(
         port: Optional[int] = None,
         username: Optional[str] = None,
         fs_factory: Optional[Callable] = None,
+        setup_script: Optional[str] = None,
         **kwargs,
     ):
         assert host
@@ -59,6 +64,7 @@ def __init__(
         self.username = username
         self._fs_factory = fs_factory
         self._repo_abspath = ""
+        self._setup_script = setup_script
 
     @classmethod
     def gen_dirname(cls, name: Optional[str] = None):
@@ -74,14 +80,15 @@ def from_stash_entry(
         entry: "ExpStashEntry",
         **kwargs,
     ):
-        manager: "MachineManager" = kwargs.pop("manager")
         machine_name: Optional[str] = kwargs.pop("machine_name", None)
         executor = cls._from_stash_entry(
             repo,
             stash_rev,
             entry,
             cls.gen_dirname(entry.name),
-            **manager.get_executor_kwargs(machine_name),
+            location=machine_name,
+            **repo.machine.get_executor_kwargs(machine_name),
+            setup_script=repo.machine.get_setup_script(machine_name),
         )
         logger.debug("Init SSH executor for host '%s'", executor.host)
         return executor
@@ -151,21 +158,128 @@ def init_git(self, scm: "Git", branch: Optional[str] = None):
             merge_rev = scm.get_ref(EXEC_MERGE)
             self._ssh_cmd(fs, f"git merge --squash --no-commit {merge_rev}")
 
+            if self._setup_script:
+                self._init_setup_script(fs)
+
+    @classmethod
+    def _setup_script_path(cls, dvc_dir: str):
+        return posixpath.join(
+            dvc_dir,
+            "tmp",
+            cls.SETUP_SCRIPT_FILENAME,
+        )
+
+    def _init_setup_script(self, fs: "SSHFileSystem"):
+        assert self._repo_abspath
+        script_path = self._setup_script_path(
+            posixpath.join(self._repo_abspath, self.dvc_dir)
+        )
+        fs.upload(self._setup_script, script_path)
+
     def _ssh_cmd(self, sshfs, cmd, chdir=None, **kwargs):
         working_dir = chdir or self.root_dir
         return sshfs.fs.execute(f"cd {working_dir};{cmd}", **kwargs)
 
-    def init_cache(self, dvc: "Repo", rev: str, run_cache: bool = True):
-        from dvc.data.db import ODBManager, get_odb
-        from dvc.repo import Repo
+    def init_cache(self, repo: "Repo", rev: str, run_cache: bool = True):
         from dvc.repo.push import push
 
+        with self.get_odb() as odb:
+            push(
+                repo,
+                revs=[rev],
+                run_cache=run_cache,
+                odb=odb,
+                include_imports=True,
+            )
+
+    def collect_cache(
+        self, repo: "Repo", exp_ref: "ExpRefInfo", run_cache: bool = True
+    ):
+        """Collect DVC cache."""
+        from dvc.repo.experiments.pull import _pull_cache
+
+        with self.get_odb() as odb:
+            _pull_cache(repo, exp_ref, run_cache=run_cache, odb=odb)
+
+    @contextmanager
+    def get_odb(self):
+        from dvc.data.db import ODBManager, get_odb
+
         cache_path = posixpath.join(
             self._repo_abspath,
-            Repo.DVC_DIR,
+            self.dvc_dir,
             ODBManager.CACHE_DIR,
         )
         with self.sshfs() as fs:
-            odb = get_odb(fs, cache_path, **fs.config)
-            push(dvc, revs=[rev], run_cache=run_cache, odb=odb)
+            yield get_odb(fs, cache_path, **fs.config)
+
+    def fetch_exps(self, *args, **kwargs) -> Iterable[str]:
+        with self.sshfs() as fs:
+            kwargs.update(self._git_client_args(fs))
+            return super().fetch_exps(*args, **kwargs)
+
+    @classmethod
+    def reproduce(
+        cls,
+        info: "ExecutorInfo",
+        rev: str,
+        queue: Optional["Queue"] = None,
+        infofile: Optional[str] = None,
+        log_errors: bool = True,
+        log_level: Optional[int] = None,
+        **kwargs,
+    ) -> "ExecutorResult":
+        """Reproduce an experiment on a remote machine over SSH.
+
+        Internally uses 'dvc exp exec-run' over SSH.
+        """
+        import json
+        import time
+        from tempfile import TemporaryFile
+
+        from asyncssh import ProcessError
+
+        fs_factory: Optional[Callable] = kwargs.pop("fs_factory", None)
+        if log_errors and log_level is not None:
+            cls._set_log_level(log_level)
+
+        with _sshfs(fs_factory) as fs:
+            while not fs.exists("/var/log/dvc-machine-init.log"):
+                logger.info(
+                    "Waiting for dvc-machine startup script to complete..."
+                )
+                time.sleep(5)
+            logger.info(
+                "Reproducing experiment on '%s'", fs.fs_args.get("host")
+            )
+            with TemporaryFile(mode="w+", encoding="utf-8") as fobj:
+                json.dump(info.asdict(), fobj)
+                fobj.seek(0)
+                fs.upload_fobj(fobj, infofile)
+            cmd = ["source ~/.profile"]
+            script_path = cls._setup_script_path(info.dvc_dir)
+            if fs.exists(posixpath.join(info.root_dir, script_path)):
+                cmd.extend(
+                    [f"pushd {info.root_dir}", f"source {script_path}", "popd"]
+                )
+            exec_cmd = f"dvc exp exec-run --infofile {infofile}"
+            if log_level is not None:
+                if log_level <= logging.TRACE:  # type: ignore[attr-defined]
+                    exec_cmd += " -vv"
+                elif log_level <= logging.DEBUG:
+                    exec_cmd += " -v"
+            cmd.append(exec_cmd)
+            try:
+                sys.stdout.flush()
+                sys.stderr.flush()
+                stdout = os.dup(sys.stdout.fileno())
+                stderr = os.dup(sys.stderr.fileno())
+                fs.fs.execute("; ".join(cmd), stdout=stdout, stderr=stderr)
+                with fs.open(infofile) as fobj:
+                    result_info = ExecutorInfo.from_dict(json.load(fobj))
+                if result_info.result_hash:
+                    return result_info.result
+            except ProcessError:
+                pass
+        return ExecutorResult(None, None, False)
diff --git a/dvc/repo/experiments/pull.py b/dvc/repo/experiments/pull.py
index d2440bf4a6..8760e523e4 100644
--- a/dvc/repo/experiments/pull.py
+++ b/dvc/repo/experiments/pull.py
@@ -48,7 +48,16 @@ def on_diverged(refname: str, rev: str) -> bool:
         _pull_cache(repo, exp_ref, **kwargs)
 
 
-def _pull_cache(repo, exp_ref, dvc_remote=None, jobs=None, run_cache=False):
+def _pull_cache(
+    repo,
+    exp_ref,
+    dvc_remote=None,
+    jobs=None,
+    run_cache=False,
+    odb=None,
+):
     revs = list(exp_commits(repo.scm, [exp_ref]))
     logger.debug("dvc fetch experiment '%s'", exp_ref)
-    repo.fetch(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs)
+    repo.fetch(
+        jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, odb=odb
+    )
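Note: `SSHExecutor.reproduce` above shells out to the new plumbing command on the remote host. The command string it assembles can be modeled as a pure function (paths illustrative; the real code also handles DVC's custom TRACE log level, omitted here):

```python
import logging

def build_remote_cmd(
    root_dir: str,
    script_path: str,
    infofile: str,
    setup_script_exists: bool,
    log_level: int,
) -> str:
    cmd = ["source ~/.profile"]
    if setup_script_exists:
        cmd.extend([f"pushd {root_dir}", f"source {script_path}", "popd"])
    exec_cmd = f"dvc exp exec-run --infofile {infofile}"
    if log_level <= logging.DEBUG:
        exec_cmd += " -v"
    cmd.append(exec_cmd)
    return "; ".join(cmd)

print(
    build_remote_cmd(
        "/home/ubuntu/exp",
        ".dvc/tmp/exec-setup.sh",
        "/home/ubuntu/exp/.dvc/tmp/abc.run",
        True,
        logging.INFO,
    )
)
```

Joining with `"; "` keeps everything in one remote shell session, so the sourced profile and setup script affect the environment of the `dvc exp exec-run` invocation that follows.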
diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py
index c7af1179b2..80acdf2d75 100644
--- a/dvc/repo/fetch.py
+++ b/dvc/repo/fetch.py
@@ -1,10 +1,14 @@
 import logging
+from typing import TYPE_CHECKING, Optional
 
 from dvc.exceptions import DownloadError, FileTransferError
 from dvc.scheme import Schemes
 
 from . import locked
 
+if TYPE_CHECKING:
+    from dvc.objects.db.base import ObjectDB
+
 logger = logging.getLogger(__name__)
 
@@ -21,6 +25,7 @@ def fetch(
     all_commits=False,
     run_cache=False,
     revs=None,
+    odb: Optional["ObjectDB"] = None,
 ):
     """Download data items from a cloud and imported repositories
 
@@ -59,20 +64,34 @@ def fetch(
         except DownloadError as exc:
             failed += exc.amount
 
-    for odb, obj_ids in sorted(
-        used.items(),
-        key=lambda item: item[0] is not None
-        and item[0].fs.scheme == Schemes.MEMORY,
-    ):
+    if odb:
+        all_ids = set()
+        for _odb, obj_ids in used.items():
+            all_ids.update(obj_ids)
         d, f = _fetch(
             self,
-            obj_ids,
+            all_ids,
             jobs=jobs,
             remote=remote,
             odb=odb,
         )
         downloaded += d
         failed += f
+    else:
+        for src_odb, obj_ids in sorted(
+            used.items(),
+            key=lambda item: item[0] is not None
+            and item[0].fs.scheme == Schemes.MEMORY,
+        ):
+            d, f = _fetch(
+                self,
+                obj_ids,
+                jobs=jobs,
+                remote=remote,
+                odb=src_odb,
+            )
+            downloaded += d
+            failed += f
 
     if failed:
         raise DownloadError(failed)
diff --git a/dvc/repo/pull.py b/dvc/repo/pull.py
index e0714da3f6..b4f472933f 100644
--- a/dvc/repo/pull.py
+++ b/dvc/repo/pull.py
@@ -1,8 +1,12 @@
 import logging
+from typing import TYPE_CHECKING, Optional
 
 from dvc.repo import locked
 from dvc.utils import glob_targets
 
+if TYPE_CHECKING:
+    from dvc.objects.db.base import ObjectDB
+
 logger = logging.getLogger(__name__)
@@ -20,6 +24,7 @@ def pull(
     all_commits=False,
     run_cache=False,
     glob=False,
+    odb: Optional["ObjectDB"] = None,
 ):
     if isinstance(targets, str):
         targets = [targets]
@@ -36,6 +41,7 @@ def pull(
         with_deps=with_deps,
         recursive=recursive,
         run_cache=run_cache,
+        odb=odb,
     )
     stats = self.checkout(
         targets=expanded_targets,
diff --git a/dvc/repo/push.py b/dvc/repo/push.py
index 7f5320ab38..25a9b36922 100644
--- a/dvc/repo/push.py
+++ b/dvc/repo/push.py
@@ -24,6 +24,7 @@ def push(
     revs=None,
     glob=False,
     odb: Optional["ObjectDB"] = None,
+    include_imports=False,
 ):
     used_run_cache = (
         self.stage_cache.push(remote, odb=odb) if run_cache else []
@@ -49,13 +50,24 @@ def push(
     )
 
     pushed = len(used_run_cache)
-    for dest_odb, obj_ids in used.items():
-        if dest_odb and dest_odb.read_only:
-            continue
+    if odb:
+        all_ids = set()
+        for dest_odb, obj_ids in used.items():
+            if not include_imports and dest_odb and dest_odb.read_only:
+                continue
+            all_ids.update(obj_ids)
         try:
-            pushed += self.cloud.push(
-                obj_ids, jobs, remote=remote, odb=odb or dest_odb
-            )
+            pushed += self.cloud.push(all_ids, jobs, remote=remote, odb=odb)
         except FileTransferError as exc:
             raise UploadError(exc.amount)
+    else:
+        for dest_odb, obj_ids in used.items():
+            if dest_odb and dest_odb.read_only:
+                continue
+            try:
+                pushed += self.cloud.push(
+                    obj_ids, jobs, remote=remote, odb=odb or dest_odb
+                )
+            except FileTransferError as exc:
+                raise UploadError(exc.amount)
     return pushed
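Note: the `odb` branches added to `push()` and `fetch()` share one idea: when an explicit destination ODB is given, object ids from all source ODBs are merged into a single transfer, with `include_imports` controlling (on push) whether ids belonging to read-only import ODBs are included. A pure-data model of that aggregation:

```python
from typing import Dict, Optional, Set

def collect_ids(
    used: Dict[Optional[str], Set[str]], include_imports: bool
) -> Set[str]:
    read_only = {"import-odb"}  # stand-in for dest_odb.read_only
    all_ids: Set[str] = set()
    for dest_odb, obj_ids in used.items():
        if not include_imports and dest_odb in read_only:
            continue
        all_ids.update(obj_ids)
    return all_ids

used = {None: {"a", "b"}, "import-odb": {"c"}}
assert collect_ids(used, include_imports=True) == {"a", "b", "c"}
assert collect_ids(used, include_imports=False) == {"a", "b"}
```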
diff --git a/tests/func/experiments/executor/test_ssh.py b/tests/func/experiments/executor/test_ssh.py
index a2128de0fb..104e490660 100644
--- a/tests/func/experiments/executor/test_ssh.py
+++ b/tests/func/experiments/executor/test_ssh.py
@@ -1,12 +1,14 @@
 import posixpath
 from contextlib import contextmanager
 from functools import partial
+from urllib.parse import urlparse
 
 import pytest
 from dvc_ssh.tests.cloud import TEST_SSH_KEY_PATH, TEST_SSH_USER
 
 from dvc.fs.ssh import SSHFileSystem
 from dvc.repo.experiments.base import EXEC_HEAD, EXEC_MERGE
+from dvc.repo.experiments.executor.base import ExecutorInfo, ExecutorResult
 from dvc.repo.experiments.executor.ssh import SSHExecutor
 from tests.func.machine.conftest import *  # noqa, pylint: disable=wildcard-import
@@ -26,10 +28,9 @@ def test_init_from_stash(tmp_dir, scm, dvc, machine_instance, mocker):
     mock_entry = mocker.Mock()
     mock_entry.name = ""
     SSHExecutor.from_stash_entry(
-        None,
-        "",
+        dvc,
+        "abc123",
         mock_entry,
-        manager=dvc.machine,
         machine_name="foo",
     )
     _args, kwargs = mock.call_args
@@ -40,8 +41,8 @@ def test_init_from_stash(tmp_dir, scm, dvc, machine_instance, mocker):
 @pytest.mark.parametrize("cloud", [pytest.lazy_fixture("git_ssh")])
 def test_init_git(tmp_dir, scm, cloud):
     tmp_dir.scm_gen({"foo": "foo", "dir": {"bar": "bar"}}, commit="init")
-    rev = scm.get_rev()
-    scm.set_ref(EXEC_HEAD, rev)
+    baseline_rev = scm.get_rev()
+    scm.set_ref(EXEC_HEAD, baseline_rev)
     tmp_dir.gen("foo", "stashed")
     scm.gitpython.git.stash()
     rev = scm.resolve_rev("stash@{0}")
@@ -50,14 +51,15 @@ def test_init_git(tmp_dir, scm, cloud):
 
     root_url = cloud / SSHExecutor.gen_dirname()
     executor = SSHExecutor(
-        scm,
-        ".",
         root_dir=root_url.path,
+        dvc_dir=".dvc",
+        baseline_rev=baseline_rev,
         host=root_url.host,
         port=root_url.port,
         username=TEST_SSH_USER,
         fs_factory=partial(_ssh_factory, cloud),
     )
+    executor.init_git(scm)
     assert root_url.path == executor._repo_abspath
     fs = cloud._ssh
 
@@ -76,9 +78,9 @@ def test_init_cache(tmp_dir, dvc, scm, cloud):
 
     root_url = cloud / SSHExecutor.gen_dirname()
     executor = SSHExecutor(
-        scm,
-        ".",
         root_dir=root_url.path,
+        dvc_dir=".dvc",
+        baseline_rev=rev,
         host=root_url.host,
         port=root_url.port,
         username=TEST_SSH_USER,
@@ -93,3 +95,67 @@ def test_init_cache(tmp_dir, dvc, scm, cloud):
             executor._repo_abspath, ".dvc", "cache", foo_hash[:2], foo_hash[2:]
         )
     )
+
+
+@pytest.mark.needs_internet
+@pytest.mark.parametrize("cloud", [pytest.lazy_fixture("git_ssh")])
+def test_reproduce(tmp_dir, scm, dvc, cloud, exp_stage, mocker):
+    from sshfs import SSHFileSystem as _sshfs
+
+    rev = scm.get_rev()
+    root_url = cloud / SSHExecutor.gen_dirname()
+    mocker.patch.object(SSHFileSystem, "exists", return_value=True)
+    mock_execute = mocker.patch.object(_sshfs, "execute")
+    info = ExecutorInfo(
+        str(root_url),
+        rev,
+        "machine-foo",
+        str(root_url.path),
+        ".dvc",
+    )
+    infofile = str((root_url / "foo.run").path)
+    SSHExecutor.reproduce(
+        info,
+        rev,
+        infofile=infofile,
+        fs_factory=partial(_ssh_factory, cloud),
+    )
+    mock_execute.assert_called_once()
+    _name, args, _kwargs = mock_execute.mock_calls[0]
+    assert f"dvc exp exec-run --infofile {infofile}" in args[0]
+
+
+@pytest.mark.needs_internet
+@pytest.mark.parametrize("cloud", [pytest.lazy_fixture("git_ssh")])
+def test_run_machine(tmp_dir, scm, dvc, cloud, exp_stage, mocker):
+    baseline = scm.get_rev()
+    factory = partial(_ssh_factory, cloud)
+    mocker.patch.object(
+        dvc.machine,
+        "get_executor_kwargs",
+        return_value={
+            "host": cloud.host,
+            "port": cloud.port,
+            "username": TEST_SSH_USER,
+            "fs_factory": factory,
+        },
+    )
+    mocker.patch.object(dvc.machine, "get_setup_script", return_value=None)
+    mock_repro = mocker.patch.object(
+        SSHExecutor,
+        "reproduce",
+        return_value=ExecutorResult("abc123", None, False),
+    )
+
+    tmp_dir.gen("params.yaml", "foo: 2")
+    dvc.experiments.run(exp_stage.addressing, machine="foo")
+    mock_repro.assert_called_once()
+    _name, _args, kwargs = mock_repro.mock_calls[0]
+    info = kwargs["info"]
+    url = urlparse(info.git_url)
+    assert url.scheme == "ssh"
+    assert url.hostname == cloud.host
+    assert url.port == cloud.port
+    assert info.baseline_rev == baseline
+    assert kwargs["infofile"] is not None
+    assert kwargs["fs_factory"] is not None
diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py
index 7a222f548c..6717225db1 100644
--- a/tests/unit/command/test_experiments.py
+++ b/tests/unit/command/test_experiments.py
@@ -129,6 +129,7 @@ def test_experiments_run(dvc, scm, mocker):
         "tmp_dir": False,
         "checkpoint_resume": None,
         "reset": False,
+        "machine": None,
     }
 
     default_arguments.update(repro_arguments)