From c46fa7c156341f06eb5d9990451d4ac0696ac5ec Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Wed, 5 Apr 2023 19:07:42 +0200 Subject: [PATCH] exp run: Add `--copy-paths` arg. List of paths to copy inside the temp directory. Only used if `--temp` or `--queue` is specified. Closes #5800 --- dvc/commands/experiments/exec_run.py | 11 ++++++ dvc/commands/experiments/run.py | 11 ++++++ dvc/repo/experiments/__init__.py | 11 +++--- dvc/repo/experiments/executor/base.py | 22 ++++++++++++ dvc/repo/experiments/executor/ssh.py | 3 +- dvc/repo/experiments/queue/base.py | 4 ++- dvc/repo/experiments/queue/celery.py | 10 ++++-- dvc/repo/experiments/queue/tasks.py | 7 ++-- dvc/repo/experiments/queue/tempdir.py | 9 +++-- dvc/repo/experiments/queue/workspace.py | 12 +++++-- dvc/repo/experiments/run.py | 8 ++++- tests/func/experiments/test_experiments.py | 42 +++++++++++++++++++++- tests/func/experiments/test_queue.py | 22 ++++++++++++ tests/func/experiments/test_set_params.py | 5 +-- tests/unit/command/test_experiments.py | 1 + 15 files changed, 156 insertions(+), 22 deletions(-) diff --git a/dvc/commands/experiments/exec_run.py b/dvc/commands/experiments/exec_run.py index a327f1fbe4b..942ded6a65d 100644 --- a/dvc/commands/experiments/exec_run.py +++ b/dvc/commands/experiments/exec_run.py @@ -19,6 +19,7 @@ def run(self): queue=None, log_level=logger.getEffectiveLevel(), infofile=self.args.infofile, + copy_paths=self.args.copy_paths, ) return 0 @@ -36,4 +37,14 @@ def add_parser(experiments_subparsers, parent_parser): help="Path to executor info file", default=None, ) + exec_run_parser.add_argument( + "-C", + "--copy-paths", + action="append", + default=[], + help=( + "List of ignored or untracked paths to copy into the temp directory." + " Only used if `--temp` or `--queue` is specified." + ), + ) exec_run_parser.set_defaults(func=CmdExecutorRun) diff --git a/dvc/commands/experiments/run.py b/dvc/commands/experiments/run.py index bbc59548720..ee56636f7e7 100644 --- a/dvc/commands/experiments/run.py +++ b/dvc/commands/experiments/run.py @@ -36,6 +36,7 @@ def run(self): reset=self.args.reset, tmp_dir=self.args.tmp_dir, machine=self.args.machine, + copy_paths=self.args.copy_paths, **self._common_kwargs, ) @@ -136,3 +137,13 @@ def _add_run_common(parser): # ) # metavar="", ) + parser.add_argument( + "-C", + "--copy-paths", + action="append", + default=[], + help=( + "List of ignored or untracked paths to copy into the temp directory." + " Only used if `--temp` or `--queue` is specified." + ), + ) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 7057b712126..14b25c8d974 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -1,7 +1,7 @@ import logging import os import re -from typing import TYPE_CHECKING, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from funcy import chain, first @@ -118,6 +118,7 @@ def stash_revs(self) -> Dict[str, "ExpStashEntry"]: def reproduce_one( self, tmp_dir: bool = False, + copy_paths: Optional[List[str]] = None, **kwargs, ): """Reproduce and checkout a single (standalone) experiment.""" @@ -125,7 +126,7 @@ def reproduce_one( self.tempdir_queue if tmp_dir else self.workspace_queue ) self.queue_one(exp_queue, **kwargs) - results = self._reproduce_queue(exp_queue) + results = self._reproduce_queue(exp_queue, copy_paths=copy_paths) exp_rev = first(results) if exp_rev is not None: self._log_reproduced(results, tmp_dir=tmp_dir) @@ -347,7 +348,9 @@ def reset_checkpoints(self): self.scm.remove_ref(EXEC_APPLY) @unlocked_repo - def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]: + def _reproduce_queue( + self, queue: "BaseStashQueue", copy_paths: Optional[List[str]] = None, **kwargs + ) -> Dict[str, str]: """Reproduce queued experiments. Arguments: @@ -357,7 +360,7 @@ def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]: dict mapping successfully reproduced experiment revs to their results. """ - exec_results = queue.reproduce() + exec_results = queue.reproduce(copy_paths=copy_paths) results: Dict[str, str] = {} for _, exp_result in exec_results.items(): diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index 64cb12085f9..5e4f4fba415 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -1,6 +1,7 @@ import logging import os import pickle # nosec B403 +import shutil from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import asdict, dataclass @@ -451,6 +452,7 @@ def reproduce( infofile: Optional[str] = None, log_errors: bool = True, log_level: Optional[int] = None, + copy_paths: Optional[List[str]] = None, **kwargs, ) -> "ExecutorResult": """Run dvc repro and return the result. @@ -487,6 +489,7 @@ def filter_pipeline(stages): info, infofile, log_errors=log_errors, + copy_paths=copy_paths, **kwargs, ) as dvc: if auto_push: @@ -609,6 +612,7 @@ def _repro_dvc( # noqa: C901 info: "ExecutorInfo", infofile: Optional[str] = None, log_errors: bool = True, + copy_paths: Optional[List[str]] = None, **kwargs, ): from dvc_studio_client.post_live_metrics import post_live_metrics @@ -623,6 +627,10 @@ def _repro_dvc( # noqa: C901 if cls.QUIET: dvc.scm_context.quiet = cls.QUIET old_cwd = os.getcwd() + + for path in copy_paths or []: + cls._copy_path(os.path.realpath(path), os.path.join(dvc.root_dir, path)) + if info.wdir: os.chdir(os.path.join(dvc.scm.root_dir, info.wdir)) else: @@ -792,6 +800,20 @@ def _set_log_level(level): if level is not None: dvc_logger.setLevel(level) + @staticmethod + def _copy_path(src, dst): + try: + if os.path.isfile(src): + shutil.copy(src, dst) + elif os.path.isdir(src): + shutil.copytree(src, dst) + else: + raise DvcException( + f"Unable to copy '{src}'. It is not a file or directory." + ) + except OSError as exc: + raise DvcException(f"Unable to copy '{src}' to '{dst}'.") from exc + @contextmanager def set_temp_refs(self, scm: "Git", temp_dict: Dict[str, str]): try: diff --git a/dvc/repo/experiments/executor/ssh.py b/dvc/repo/experiments/executor/ssh.py index d71a7e606d7..623c64a3041 100644 --- a/dvc/repo/experiments/executor/ssh.py +++ b/dvc/repo/experiments/executor/ssh.py @@ -3,7 +3,7 @@ import posixpath import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Iterable, Optional +from typing import TYPE_CHECKING, Callable, Iterable, List, Optional from dvc_ssh import SSHFileSystem from funcy import first @@ -242,6 +242,7 @@ def reproduce( infofile: Optional[str] = None, log_errors: bool = True, log_level: Optional[int] = None, + copy_paths: Optional[List[str]] = None, # noqa: ARG003 **kwargs, ) -> "ExecutorResult": """Reproduce an experiment on a remote machine over SSH. diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 4f4fb9726bc..39b65e99372 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -252,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]: """Iterate over items which been failed.""" @abstractmethod - def reproduce(self) -> Mapping[str, Mapping[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Mapping[str, Mapping[str, str]]: """Reproduce queued experiments sequentially.""" @abstractmethod diff --git a/dvc/repo/experiments/queue/celery.py b/dvc/repo/experiments/queue/celery.py index 20750e260aa..cc9669e7b67 100644 --- a/dvc/repo/experiments/queue/celery.py +++ b/dvc/repo/experiments/queue/celery.py @@ -174,11 +174,13 @@ def start_workers(self, count: int) -> int: return started - def put(self, *args, **kwargs) -> QueueEntry: + def put( + self, *args, copy_paths: Optional[List[str]] = None, **kwargs + ) -> QueueEntry: """Stash an experiment and add it to the queue.""" with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]): entry = self._stash_exp(*args, **kwargs) - self.celery.signature(run_exp.s(entry.asdict())).delay() + self.celery.signature(run_exp.s(entry.asdict(), copy_paths=copy_paths)).delay() return entry # NOTE: Queue consumption should not be done directly. Celery worker(s) @@ -250,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]: if exp_result is None: yield QueueDoneResult(queue_entry, exp_result) - def reproduce(self) -> Mapping[str, Mapping[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Mapping[str, Mapping[str, str]]: raise NotImplementedError def _load_info(self, rev: str) -> ExecutorInfo: diff --git a/dvc/repo/experiments/queue/tasks.py b/dvc/repo/experiments/queue/tasks.py index e32325da3f1..406e8bf45c5 100644 --- a/dvc/repo/experiments/queue/tasks.py +++ b/dvc/repo/experiments/queue/tasks.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Optional from celery import shared_task from celery.utils.log import get_task_logger @@ -91,7 +91,7 @@ def cleanup_exp(executor: TempDirExecutor, infofile: str) -> None: @shared_task -def run_exp(entry_dict: Dict[str, Any]) -> None: +def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None) -> None: """Run a full experiment. Experiment subtasks are executed inline as one atomic operation. @@ -108,6 +108,9 @@ def run_exp(entry_dict: Dict[str, Any]) -> None: executor = setup_exp.s(entry_dict)() try: cmd = ["dvc", "exp", "exec-run", "--infofile", infofile] + if copy_paths: + for path in copy_paths: + cmd.extend(["--copy-paths", path]) proc_dict = queue.proc.run_signature(cmd, name=entry.stash_rev)() collect_exp.s(proc_dict, entry_dict)() finally: diff --git a/dvc/repo/experiments/queue/tempdir.py b/dvc/repo/experiments/queue/tempdir.py index 27bcaab61f4..944e8359771 100644 --- a/dvc/repo/experiments/queue/tempdir.py +++ b/dvc/repo/experiments/queue/tempdir.py @@ -1,7 +1,7 @@ import logging import os from collections import defaultdict -from typing import TYPE_CHECKING, Dict, Generator, Optional +from typing import TYPE_CHECKING, Dict, Generator, List, Optional from funcy import first @@ -92,7 +92,11 @@ def iter_active(self) -> Generator[QueueEntry, None, None]: ) def _reproduce_entry( - self, entry: QueueEntry, executor: "BaseExecutor" + self, + entry: QueueEntry, + executor: "BaseExecutor", + copy_paths: Optional[List[str]] = None, + **kwargs, ) -> Dict[str, Dict[str, str]]: from dvc.stage.monitor import CheckpointKilledError @@ -107,6 +111,7 @@ def _reproduce_entry( infofile=infofile, log_level=logger.getEffectiveLevel(), log_errors=True, + copy_paths=copy_paths, ) if not exec_result.exp_hash: raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'") diff --git a/dvc/repo/experiments/queue/workspace.py b/dvc/repo/experiments/queue/workspace.py index 063f1c54fb9..5ba6b29db8e 100644 --- a/dvc/repo/experiments/queue/workspace.py +++ b/dvc/repo/experiments/queue/workspace.py @@ -33,6 +33,7 @@ class WorkspaceQueue(BaseStashQueue): _EXEC_NAME: Optional[str] = "workspace" def put(self, *args, **kwargs) -> QueueEntry: + kwargs.pop("copy_paths", None) with get_exp_rwlock(self.repo, writes=["workspace", WORKSPACE_STASH]): return self._stash_exp(*args, **kwargs) @@ -81,19 +82,24 @@ def iter_failed(self) -> Generator["QueueDoneResult", None, None]: def iter_success(self) -> Generator["QueueDoneResult", None, None]: raise NotImplementedError - def reproduce(self) -> Dict[str, Dict[str, str]]: + def reproduce( + self, copy_paths: Optional[List[str]] = None + ) -> Dict[str, Dict[str, str]]: results: Dict[str, Dict[str, str]] = defaultdict(dict) try: while True: entry, executor = self.get() - results.update(self._reproduce_entry(entry, executor)) + results.update( + self._reproduce_entry(entry, executor, copy_paths=copy_paths) + ) except ExpQueueEmptyError: pass return results def _reproduce_entry( - self, entry: QueueEntry, executor: "BaseExecutor" + self, entry: QueueEntry, executor: "BaseExecutor", **kwargs ) -> Dict[str, Dict[str, str]]: + kwargs.pop("copy_paths", None) from dvc.stage.monitor import CheckpointKilledError results: Dict[str, Dict[str, str]] = defaultdict(dict) diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 42e5ab91005..829ab61e1e8 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -19,6 +19,7 @@ def run( # noqa: C901, PLR0912 jobs: int = 1, tmp_dir: bool = False, queue: bool = False, + copy_paths: Optional[Iterable[str]] = None, **kwargs, ) -> Dict[str, str]: """Reproduce the specified targets as an experiment. @@ -69,7 +70,11 @@ def run( # noqa: C901, PLR0912 if not queue: return repo.experiments.reproduce_one( - targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs + targets=targets, + params=path_overrides, + tmp_dir=tmp_dir, + copy_paths=copy_paths, + **kwargs, ) if hydra_sweep: @@ -90,6 +95,7 @@ def run( # noqa: C901, PLR0912 repo.experiments.celery_queue, targets=targets, params=sweep_overrides, + copy_paths=copy_paths, **kwargs, ) if sweep_overrides: diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 794adf7bc7c..ae9140d56f3 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -9,7 +9,7 @@ from funcy import first from dvc.dvcfile import PROJECT_FILE -from dvc.exceptions import ReproductionError +from dvc.exceptions import DvcException, ReproductionError from dvc.repo.experiments.exceptions import ExperimentExistsError from dvc.repo.experiments.queue.base import BaseStashQueue from dvc.repo.experiments.refs import CELERY_STASH @@ -704,3 +704,43 @@ def test_untracked_top_level_files_are_included_in_exp(tmp_dir, scm, dvc, tmp): fs = scm.get_fs(exp) for file in ["metrics.json", "params.yaml", "plots.csv"]: assert fs.exists(file) + + +@pytest.mark.parametrize("tmp", [True, False]) +def test_copy_paths(tmp_dir, scm, dvc, tmp): + stage = dvc.stage.add( + cmd="cat file && ls dir", + name="foo", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + + (tmp_dir / "dir").mkdir() + (tmp_dir / "dir" / "file").write_text("dir/file") + scm.ignore(tmp_dir / "dir") + (tmp_dir / "file").write_text("file") + scm.ignore(tmp_dir / "file") + + results = dvc.experiments.run( + stage.addressing, tmp_dir=tmp, copy_paths=["dir", "file"] + ) + exp = first(results) + fs = scm.get_fs(exp) + assert not fs.exists("dir") + assert not fs.exists("file") + + +def test_copy_paths_errors(tmp_dir, scm, dvc, mocker): + stage = dvc.stage.add( + cmd="echo foo", + name="foo", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + + with pytest.raises(DvcException, match="Unable to copy"): + dvc.experiments.run(stage.addressing, tmp_dir=True, copy_paths=["foo"]) + + (tmp_dir / "foo").write_text("foo") + mocker.patch("shutil.copy", side_effect=OSError) + + with pytest.raises(DvcException, match="Unable to copy"): + dvc.experiments.run(stage.addressing, tmp_dir=True, copy_paths=["foo"]) diff --git a/tests/func/experiments/test_queue.py b/tests/func/experiments/test_queue.py index eee07ecc109..5f90656e94a 100644 --- a/tests/func/experiments/test_queue.py +++ b/tests/func/experiments/test_queue.py @@ -108,3 +108,25 @@ def test_queue_doesnt_remove_untracked_params_file(tmp_dir, dvc, scm): scm.commit("init") dvc.experiments.run(stage.addressing, params=["foo=2"], queue=True) assert (tmp_dir / "params.yaml").exists() + + +def test_copy_paths_queue(tmp_dir, scm, dvc): + stage = dvc.stage.add( + cmd="cat file && ls dir", + name="foo", + ) + scm.add_commit(["dvc.yaml"], message="add dvc.yaml") + + (tmp_dir / "dir").mkdir() + (tmp_dir / "dir" / "file").write_text("dir/file") + scm.ignore(tmp_dir / "dir") + (tmp_dir / "file").write_text("file") + scm.ignore(tmp_dir / "file") + + dvc.experiments.run(stage.addressing, queue=True) + results = dvc.experiments.run(run_all=True) + + exp = first(results) + fs = scm.get_fs(exp) + assert not fs.exists("dir") + assert not fs.exists("file") diff --git a/tests/func/experiments/test_set_params.py b/tests/func/experiments/test_set_params.py index d708d6f268a..f3de6f2dd34 100644 --- a/tests/func/experiments/test_set_params.py +++ b/tests/func/experiments/test_set_params.py @@ -116,10 +116,7 @@ def test_hydra_sweep( assert patched.call_count == len(expected) for e in expected: patched.assert_any_call( - mocker.ANY, - params=e, - reset=True, - targets=None, + mocker.ANY, params=e, reset=True, targets=None, copy_paths=None ) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 9ca3fd54f70..fac882e2450 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -138,6 +138,7 @@ def test_experiments_run(dvc, scm, mocker): "checkpoint_resume": None, "reset": False, "machine": None, + "copy_paths": [], } default_arguments.update(repro_arguments)