From ce8434a25db8856e1df43ec54abcb8bdc23a0342 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Tue, 4 Apr 2023 18:20:28 +0200 Subject: [PATCH] exp run: Add `--copy-paths` arg. List of paths to copy inside the temp directory. Only used if `--temp` or `--queue` is specified. Closes #5800 --- dvc/commands/experiments/exec_run.py | 10 +++++ dvc/commands/experiments/run.py | 11 ++++++ dvc/repo/experiments/__init__.py | 11 ++++-- dvc/repo/experiments/executor/base.py | 16 ++++++++ dvc/repo/experiments/executor/ssh.py | 3 +- dvc/repo/experiments/queue/base.py | 4 +- dvc/repo/experiments/queue/celery.py | 10 +++-- dvc/repo/experiments/queue/tasks.py | 7 +++- dvc/repo/experiments/queue/tempdir.py | 9 ++++- dvc/repo/experiments/queue/workspace.py | 12 ++++-- dvc/repo/experiments/run.py | 8 +++- tests/func/experiments/test_experiments.py | 45 ++++++++++++++++++++++ tests/func/experiments/test_set_params.py | 1 + tests/unit/command/test_experiments.py | 1 + 14 files changed, 131 insertions(+), 17 deletions(-) diff --git a/dvc/commands/experiments/exec_run.py b/dvc/commands/experiments/exec_run.py index a327f1fbe4..6716d66bee 100644 --- a/dvc/commands/experiments/exec_run.py +++ b/dvc/commands/experiments/exec_run.py @@ -19,6 +19,7 @@ def run(self): queue=None, log_level=logger.getEffectiveLevel(), infofile=self.args.infofile, + copy_paths=self.args.copy_paths, ) return 0 @@ -36,4 +37,13 @@ def add_parser(experiments_subparsers, parent_parser): help="Path to executor info file", default=None, ) + exec_run_parser.add_argument( + "--copy-paths", + action="append", + default=[], + help=( + "List of paths to copy inside the temp directory." + " Only used if `--temp` or `--run-all` is specified." + ), + ) exec_run_parser.set_defaults(func=CmdExecutorRun) diff --git a/dvc/commands/experiments/run.py b/dvc/commands/experiments/run.py index bbc5954872..690ff22e19 100644 --- a/dvc/commands/experiments/run.py +++ b/dvc/commands/experiments/run.py @@ -36,6 +36,7 @@ def run(self): reset=self.args.reset, tmp_dir=self.args.tmp_dir, machine=self.args.machine, + copy_paths=self.args.copy_paths, **self._common_kwargs, ) @@ -136,3 +137,13 @@ def _add_run_common(parser): # ) # metavar="", ) + parser.add_argument( + "-C", + "--copy-paths", + action="append", + default=[], + help=( + "List of paths to copy inside the temp directory." + " Only used if `--temp` or `--queue` is specified." + ), + ) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 7057b71212..14b25c8d97 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -1,7 +1,7 @@ import logging import os import re -from typing import TYPE_CHECKING, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from funcy import chain, first @@ -118,6 +118,7 @@ def stash_revs(self) -> Dict[str, "ExpStashEntry"]: def reproduce_one( self, tmp_dir: bool = False, + copy_paths: Optional[List[str]] = None, **kwargs, ): """Reproduce and checkout a single (standalone) experiment.""" @@ -125,7 +126,7 @@ def reproduce_one( self.tempdir_queue if tmp_dir else self.workspace_queue ) self.queue_one(exp_queue, **kwargs) - results = self._reproduce_queue(exp_queue) + results = self._reproduce_queue(exp_queue, copy_paths=copy_paths) exp_rev = first(results) if exp_rev is not None: self._log_reproduced(results, tmp_dir=tmp_dir) @@ -347,7 +348,9 @@ def reset_checkpoints(self): self.scm.remove_ref(EXEC_APPLY) @unlocked_repo - def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]: + def _reproduce_queue( + self, queue: "BaseStashQueue", copy_paths: Optional[List[str]] = None, **kwargs + ) -> Dict[str, str]: """Reproduce queued experiments. Arguments: @@ -357,7 +360,7 @@ def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]: dict mapping successfully reproduced experiment revs to their results. """ - exec_results = queue.reproduce() + exec_results = queue.reproduce(copy_paths=copy_paths) results: Dict[str, str] = {} for _, exp_result in exec_results.items(): diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 64cb12085f..e0affa4d86 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -1,6 +1,7 @@ import logging import os import pickle # nosec B403 +import shutil from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import asdict, dataclass @@ -451,6 +452,7 @@ def reproduce( infofile: Optional[str] = None, log_errors: bool = True, log_level: Optional[int] = None, + copy_paths: Optional[List[str]] = None, **kwargs, ) -> "ExecutorResult": """Run dvc repro and return the result. @@ -487,6 +489,7 @@ def filter_pipeline(stages): info, infofile, log_errors=log_errors, + copy_paths=copy_paths, **kwargs, ) as dvc: if auto_push: @@ -609,6 +612,7 @@ def _repro_dvc( # noqa: C901 info: "ExecutorInfo", infofile: Optional[str] = None, log_errors: bool = True, + copy_paths: Optional[List[str]] = None, **kwargs, ): from dvc_studio_client.post_live_metrics import post_live_metrics @@ -623,6 +627,18 @@ def _repro_dvc( # noqa: C901 if cls.QUIET: dvc.scm_context.quiet = cls.QUIET old_cwd = os.getcwd() + + if copy_paths: + for path in copy_paths: + if os.path.isfile(path): + shutil.copy( + os.path.realpath(path), os.path.join(dvc.root_dir, path) + ) + elif os.path.isdir(path): + shutil.copytree( + os.path.realpath(path), os.path.join(dvc.root_dir, path) + ) + if info.wdir: os.chdir(os.path.join(dvc.scm.root_dir, info.wdir)) else: diff --git a/dvc/repo/experiments/executor/ssh.py b/dvc/repo/experiments/executor/ssh.py index d71a7e606d..623c64a304 100644 --- a/dvc/repo/experiments/executor/ssh.py +++ b/dvc/repo/experiments/executor/ssh.py @@ -3,7 +3,7 @@ import posixpath import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Iterable, Optional +from typing import TYPE_CHECKING, Callable, Iterable, List, Optional from dvc_ssh import SSHFileSystem from funcy import first @@ -242,6 +242,7 @@ def reproduce( infofile: Optional[str] = None, log_errors: bool = True, log_level: Optional[int] = None, + copy_paths: Optional[List[str]] = None, # noqa: ARG003 **kwargs, ) -> "ExecutorResult": """Reproduce an experiment on a remote machine over SSH. diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 4f4fb9726b..39b65e9937 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -252,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]: """Iterate over items which been failed.""" @abstractmethod - def reproduce(self) -> Mapping[str, Mapping[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Mapping[str, Mapping[str, str]]: """Reproduce queued experiments sequentially.""" @abstractmethod diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index 20750e260a..cc9669e7b6 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -174,11 +174,13 @@ def start_workers(self, count: int) -> int: return started - def put(self, *args, **kwargs) -> QueueEntry: + def put( + self, *args, copy_paths: Optional[List[str]] = None, **kwargs + ) -> QueueEntry: """Stash an experiment and add it to the queue.""" with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]): entry = self._stash_exp(*args, **kwargs) - self.celery.signature(run_exp.s(entry.asdict())).delay() + self.celery.signature(run_exp.s(entry.asdict(), copy_paths=copy_paths)).delay() return entry # NOTE: Queue consumption should not be done directly. Celery worker(s) @@ -250,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]: if exp_result is None: yield QueueDoneResult(queue_entry, exp_result) - def reproduce(self) -> Mapping[str, Mapping[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Mapping[str, Mapping[str, str]]: raise NotImplementedError def _load_info(self, rev: str) -> ExecutorInfo: diff --git a/dvc/repo/experiments/queue/tasks.py b/dvc/repo/experiments/queue/tasks.py index e32325da3f..406e8bf45c 100644 --- a/dvc/repo/experiments/queue/tasks.py +++ b/dvc/repo/experiments/queue/tasks.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Optional from celery import shared_task from celery.utils.log import get_task_logger @@ -91,7 +91,7 @@ def cleanup_exp(executor: TempDirExecutor, infofile: str) -> None: @shared_task -def run_exp(entry_dict: Dict[str, Any]) -> None: +def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None) -> None: """Run a full experiment. Experiment subtasks are executed inline as one atomic operation. @@ -108,6 +108,9 @@ def run_exp(entry_dict: Dict[str, Any]) -> None: executor = setup_exp.s(entry_dict)() try: cmd = ["dvc", "exp", "exec-run", "--infofile", infofile] + if copy_paths: + for path in copy_paths: + cmd.extend(["--copy-paths", path]) proc_dict = queue.proc.run_signature(cmd, name=entry.stash_rev)() collect_exp.s(proc_dict, entry_dict)() finally: diff --git a/dvc/repo/experiments/queue/tempdir.py b/dvc/repo/experiments/queue/tempdir.py index 27bcaab61f..944e835977 100644 --- a/dvc/repo/experiments/queue/tempdir.py +++ b/dvc/repo/experiments/queue/tempdir.py @@ -1,7 +1,7 @@ import logging import os from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Generator, Optional +from typing import TYPE_CHECKING, Dict, Generator, List, Optional from funcy import first @@ -92,7 +92,11 @@ def iter_active(self) -> Generator[QueueEntry, None, None]: ) def _reproduce_entry( - self, entry: QueueEntry, executor: "BaseExecutor" + self, + entry: QueueEntry, + executor: "BaseExecutor", + copy_paths: Optional[List[str]] = None, + **kwargs, ) -> Dict[str, Dict[str, str]]: from dvc.stage.monitor import CheckpointKilledError @@ -107,6 +111,7 @@ def _reproduce_entry( infofile=infofile, log_level=logger.getEffectiveLevel(), log_errors=True, + copy_paths=copy_paths, ) if not exec_result.exp_hash: raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'") diff --git a/dvc/repo/experiments/queue/workspace.py b/dvc/repo/experiments/queue/workspace.py index 063f1c54fb..5ba6b29db8 100644 --- a/dvc/repo/experiments/queue/workspace.py +++ b/dvc/repo/experiments/queue/workspace.py @@ -33,6 +33,7 @@ class WorkspaceQueue(BaseStashQueue): _EXEC_NAME: Optional[str] = "workspace" def put(self, *args, **kwargs) -> QueueEntry: + kwargs.pop("copy_paths", None) with get_exp_rwlock(self.repo, writes=["workspace", WORKSPACE_STASH]): return self._stash_exp(*args, **kwargs) @@ -81,19 +82,24 @@ def iter_failed(self) -> Generator["QueueDoneResult", None, None]: def iter_success(self) -> Generator["QueueDoneResult", None, None]: raise NotImplementedError - def reproduce(self) -> Dict[str, Dict[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Dict[str, Dict[str, str]]: results: Dict[str, Dict[str, str]] = defaultdict(dict) try: while True: entry, executor = self.get() - results.update(self._reproduce_entry(entry, executor)) + results.update( + self._reproduce_entry(entry, executor, copy_paths=copy_paths) + ) except ExpQueueEmptyError: pass return results def _reproduce_entry( - self, entry: QueueEntry, executor: "BaseExecutor" + self, entry: QueueEntry, executor: "BaseExecutor", **kwargs ) -> Dict[str, Dict[str, str]]: + kwargs.pop("copy_paths", None) from dvc.stage.monitor import CheckpointKilledError results: Dict[str, Dict[str, str]] = defaultdict(dict) diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 5975a375d1..42647ab437 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -19,6 +19,7 @@ def run( # noqa: C901 jobs: int = 1, tmp_dir: bool = False, queue: bool = False, + copy_paths: Optional[Iterable[str]] = None, **kwargs, ) -> Dict[str, str]: """Reproduce the specified targets as an experiment. @@ -57,7 +58,11 @@ def run( # noqa: C901 if not queue: return repo.experiments.reproduce_one( - targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs + targets=targets, + params=path_overrides, + tmp_dir=tmp_dir, + copy_paths=copy_paths, + **kwargs, ) if hydra_sweep: @@ -78,6 +83,7 @@ def run( # noqa: C901 repo.experiments.celery_queue, targets=targets, params=sweep_overrides, + copy_paths=copy_paths, **kwargs, ) if sweep_overrides: diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 794adf7bc7..ca6ae70985 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -704,3 +704,48 @@ def test_untracked_top_level_files_are_included_in_exp(tmp_dir, scm, dvc, tmp): fs = scm.get_fs(exp) for file in ["metrics.json", "params.yaml", "plots.csv"]: assert fs.exists(file) + + +@pytest.mark.parametrize("tmp", [True, False]) +def test_copy_paths(tmp_dir, scm, dvc, tmp): + stage = dvc.stage.add( + cmd="cat file && ls dir", + name="foo", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + + (tmp_dir / "dir").mkdir() + (tmp_dir / "dir" / "file").write_text("dir/file") + scm.ignore(tmp_dir / "dir") + (tmp_dir / "file").write_text("file") + scm.ignore(tmp_dir / "file") + + results = dvc.experiments.run( + stage.addressing, tmp_dir=tmp, copy_paths=["dir", "file"] + ) + exp = first(results) + fs = scm.get_fs(exp) + assert not fs.exists("dir") + assert not fs.exists("file") + + +def test_copy_paths_queue(tmp_dir, scm, dvc): + stage = dvc.stage.add( + cmd="cat file && ls dir", + name="foo", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + + (tmp_dir / "dir").mkdir() + (tmp_dir / "dir" / "file").write_text("dir/file") + scm.ignore(tmp_dir / "dir") + (tmp_dir / "file").write_text("file") + scm.ignore(tmp_dir / "file") + + dvc.experiments.run(stage.addressing, queue=True) + results = dvc.experiments.run(run_all=True) + + exp = first(results) + fs = scm.get_fs(exp) + assert not fs.exists("dir") + assert not fs.exists("file") diff --git a/tests/func/experiments/test_set_params.py b/tests/func/experiments/test_set_params.py index d708d6f268..2ffd7a8ab9 100644 --- a/tests/func/experiments/test_set_params.py +++ b/tests/func/experiments/test_set_params.py @@ -120,6 +120,7 @@ def test_hydra_sweep( params=e, reset=True, targets=None, + copy_paths=None ) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 9ca3fd54f7..fac882e245 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -138,6 +138,7 @@ def test_experiments_run(dvc, scm, mocker): "checkpoint_resume": None, "reset": False, "machine": None, + "copy_paths": [], } default_arguments.update(repro_arguments)