Skip to content

Commit

Permalink
exp run: Add --copy-paths arg.
Browse files Browse the repository at this point in the history
List of paths to copy inside the temp directory. Only used if `--temp` or `--queue` is specified.

Closes #5800
  • Loading branch information
daavoo committed Apr 4, 2023
1 parent cf52efb commit ce8434a
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 17 deletions.
10 changes: 10 additions & 0 deletions dvc/commands/experiments/exec_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run(self):
queue=None,
log_level=logger.getEffectiveLevel(),
infofile=self.args.infofile,
copy_paths=self.args.copy_paths,
)
return 0

Expand All @@ -36,4 +37,13 @@ def add_parser(experiments_subparsers, parent_parser):
help="Path to executor info file",
default=None,
)
exec_run_parser.add_argument(
"--copy-paths",
action="append",
default=[],
help=(
"List of paths to copy inside the temp directory."
" Only used if `--temp` or `--run-all` is specified."
),
)
exec_run_parser.set_defaults(func=CmdExecutorRun)
11 changes: 11 additions & 0 deletions dvc/commands/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run(self):
reset=self.args.reset,
tmp_dir=self.args.tmp_dir,
machine=self.args.machine,
copy_paths=self.args.copy_paths,
**self._common_kwargs,
)

Expand Down Expand Up @@ -136,3 +137,13 @@ def _add_run_common(parser):
# )
# metavar="<name>",
)
parser.add_argument(
"-C",
"--copy-paths",
action="append",
default=[],
help=(
"List of paths to copy inside the temp directory."
" Only used if `--temp` or `--queue` is specified."
),
)
11 changes: 7 additions & 4 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional

from funcy import chain, first

Expand Down Expand Up @@ -118,14 +118,15 @@ def stash_revs(self) -> Dict[str, "ExpStashEntry"]:
def reproduce_one(
self,
tmp_dir: bool = False,
copy_paths: Optional[List[str]] = None,
**kwargs,
):
"""Reproduce and checkout a single (standalone) experiment."""
exp_queue: "BaseStashQueue" = (
self.tempdir_queue if tmp_dir else self.workspace_queue
)
self.queue_one(exp_queue, **kwargs)
results = self._reproduce_queue(exp_queue)
results = self._reproduce_queue(exp_queue, copy_paths=copy_paths)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
Expand Down Expand Up @@ -347,7 +348,9 @@ def reset_checkpoints(self):
self.scm.remove_ref(EXEC_APPLY)

@unlocked_repo
def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]:
def _reproduce_queue(
self, queue: "BaseStashQueue", copy_paths: Optional[List[str]] = None, **kwargs
) -> Dict[str, str]:
"""Reproduce queued experiments.
Arguments:
Expand All @@ -357,7 +360,7 @@ def _reproduce_queue(self, queue: "BaseStashQueue", **kwargs) -> Dict[str, str]:
dict mapping successfully reproduced experiment revs to their
results.
"""
exec_results = queue.reproduce()
exec_results = queue.reproduce(copy_paths=copy_paths)

results: Dict[str, str] = {}
for _, exp_result in exec_results.items():
Expand Down
16 changes: 16 additions & 0 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import pickle # nosec B403
import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -451,6 +452,7 @@ def reproduce(
infofile: Optional[str] = None,
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None,
**kwargs,
) -> "ExecutorResult":
"""Run dvc repro and return the result.
Expand Down Expand Up @@ -487,6 +489,7 @@ def filter_pipeline(stages):
info,
infofile,
log_errors=log_errors,
copy_paths=copy_paths,
**kwargs,
) as dvc:
if auto_push:
Expand Down Expand Up @@ -609,6 +612,7 @@ def _repro_dvc( # noqa: C901
info: "ExecutorInfo",
infofile: Optional[str] = None,
log_errors: bool = True,
copy_paths: Optional[List[str]] = None,
**kwargs,
):
from dvc_studio_client.post_live_metrics import post_live_metrics
Expand All @@ -623,6 +627,18 @@ def _repro_dvc( # noqa: C901
if cls.QUIET:
dvc.scm_context.quiet = cls.QUIET
old_cwd = os.getcwd()

if copy_paths:
for path in copy_paths:
if os.path.isfile(path):
shutil.copy(
os.path.realpath(path), os.path.join(dvc.root_dir, path)
)
elif os.path.isdir(path):
shutil.copytree(
os.path.realpath(path), os.path.join(dvc.root_dir, path)
)

if info.wdir:
os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
else:
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/experiments/executor/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import posixpath
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Iterable, Optional
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional

from dvc_ssh import SSHFileSystem
from funcy import first
Expand Down Expand Up @@ -242,6 +242,7 @@ def reproduce(
infofile: Optional[str] = None,
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None, # noqa: ARG003
**kwargs,
) -> "ExecutorResult":
"""Reproduce an experiment on a remote machine over SSH.
Expand Down
4 changes: 3 additions & 1 deletion dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
"""Iterate over items which been failed."""

@abstractmethod
def reproduce(self) -> Mapping[str, Mapping[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Mapping[str, Mapping[str, str]]:
"""Reproduce queued experiments sequentially."""

@abstractmethod
Expand Down
10 changes: 7 additions & 3 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ def start_workers(self, count: int) -> int:

return started

def put(self, *args, **kwargs) -> QueueEntry:
def put(
self, *args, copy_paths: Optional[List[str]] = None, **kwargs
) -> QueueEntry:
"""Stash an experiment and add it to the queue."""
with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]):
entry = self._stash_exp(*args, **kwargs)
self.celery.signature(run_exp.s(entry.asdict())).delay()
self.celery.signature(run_exp.s(entry.asdict(), copy_paths=copy_paths)).delay()
return entry

# NOTE: Queue consumption should not be done directly. Celery worker(s)
Expand Down Expand Up @@ -250,7 +252,9 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
if exp_result is None:
yield QueueDoneResult(queue_entry, exp_result)

def reproduce(self) -> Mapping[str, Mapping[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

def _load_info(self, rev: str) -> ExecutorInfo:
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/experiments/queue/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from celery import shared_task
from celery.utils.log import get_task_logger
Expand Down Expand Up @@ -91,7 +91,7 @@ def cleanup_exp(executor: TempDirExecutor, infofile: str) -> None:


@shared_task
def run_exp(entry_dict: Dict[str, Any]) -> None:
def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None) -> None:
"""Run a full experiment.
Experiment subtasks are executed inline as one atomic operation.
Expand All @@ -108,6 +108,9 @@ def run_exp(entry_dict: Dict[str, Any]) -> None:
executor = setup_exp.s(entry_dict)()
try:
cmd = ["dvc", "exp", "exec-run", "--infofile", infofile]
if copy_paths:
for path in copy_paths:
cmd.extend(["--copy-paths", path])
proc_dict = queue.proc.run_signature(cmd, name=entry.stash_rev)()
collect_exp.s(proc_dict, entry_dict)()
finally:
Expand Down
9 changes: 7 additions & 2 deletions dvc/repo/experiments/queue/tempdir.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Generator, Optional
from typing import TYPE_CHECKING, Dict, Generator, List, Optional

from funcy import first

Expand Down Expand Up @@ -92,7 +92,11 @@ def iter_active(self) -> Generator[QueueEntry, None, None]:
)

def _reproduce_entry(
self, entry: QueueEntry, executor: "BaseExecutor"
self,
entry: QueueEntry,
executor: "BaseExecutor",
copy_paths: Optional[List[str]] = None,
**kwargs,
) -> Dict[str, Dict[str, str]]:
from dvc.stage.monitor import CheckpointKilledError

Expand All @@ -107,6 +111,7 @@ def _reproduce_entry(
infofile=infofile,
log_level=logger.getEffectiveLevel(),
log_errors=True,
copy_paths=copy_paths,
)
if not exec_result.exp_hash:
raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'")
Expand Down
12 changes: 9 additions & 3 deletions dvc/repo/experiments/queue/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class WorkspaceQueue(BaseStashQueue):
_EXEC_NAME: Optional[str] = "workspace"

def put(self, *args, **kwargs) -> QueueEntry:
kwargs.pop("copy_paths", None)
with get_exp_rwlock(self.repo, writes=["workspace", WORKSPACE_STASH]):
return self._stash_exp(*args, **kwargs)

Expand Down Expand Up @@ -81,19 +82,24 @@ def iter_failed(self) -> Generator["QueueDoneResult", None, None]:
def iter_success(self) -> Generator["QueueDoneResult", None, None]:
raise NotImplementedError

def reproduce(self) -> Dict[str, Dict[str, str]]:
def reproduce(
self, copy_paths: Optional[List[str]] = None
) -> Dict[str, Dict[str, str]]:
results: Dict[str, Dict[str, str]] = defaultdict(dict)
try:
while True:
entry, executor = self.get()
results.update(self._reproduce_entry(entry, executor))
results.update(
self._reproduce_entry(entry, executor, copy_paths=copy_paths)
)
except ExpQueueEmptyError:
pass
return results

def _reproduce_entry(
self, entry: QueueEntry, executor: "BaseExecutor"
self, entry: QueueEntry, executor: "BaseExecutor", **kwargs
) -> Dict[str, Dict[str, str]]:
kwargs.pop("copy_paths", None)
from dvc.stage.monitor import CheckpointKilledError

results: Dict[str, Dict[str, str]] = defaultdict(dict)
Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run( # noqa: C901
jobs: int = 1,
tmp_dir: bool = False,
queue: bool = False,
copy_paths: Optional[Iterable[str]] = None,
**kwargs,
) -> Dict[str, str]:
"""Reproduce the specified targets as an experiment.
Expand Down Expand Up @@ -57,7 +58,11 @@ def run( # noqa: C901

if not queue:
return repo.experiments.reproduce_one(
targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs
targets=targets,
params=path_overrides,
tmp_dir=tmp_dir,
copy_paths=copy_paths,
**kwargs,
)

if hydra_sweep:
Expand All @@ -78,6 +83,7 @@ def run( # noqa: C901
repo.experiments.celery_queue,
targets=targets,
params=sweep_overrides,
copy_paths=copy_paths,
**kwargs,
)
if sweep_overrides:
Expand Down
45 changes: 45 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,48 @@ def test_untracked_top_level_files_are_included_in_exp(tmp_dir, scm, dvc, tmp):
fs = scm.get_fs(exp)
for file in ["metrics.json", "params.yaml", "plots.csv"]:
assert fs.exists(file)


@pytest.mark.parametrize("tmp", [True, False])
def test_copy_paths(tmp_dir, scm, dvc, tmp):
stage = dvc.stage.add(
cmd="cat file && ls dir",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

(tmp_dir / "dir").mkdir()
(tmp_dir / "dir" / "file").write_text("dir/file")
scm.ignore(tmp_dir / "dir")
(tmp_dir / "file").write_text("file")
scm.ignore(tmp_dir / "file")

results = dvc.experiments.run(
stage.addressing, tmp_dir=tmp, copy_paths=["dir", "file"]
)
exp = first(results)
fs = scm.get_fs(exp)
assert not fs.exists("dir")
assert not fs.exists("file")


def test_copy_paths_queue(tmp_dir, scm, dvc):
stage = dvc.stage.add(
cmd="cat file && ls dir",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

(tmp_dir / "dir").mkdir()
(tmp_dir / "dir" / "file").write_text("dir/file")
scm.ignore(tmp_dir / "dir")
(tmp_dir / "file").write_text("file")
scm.ignore(tmp_dir / "file")

dvc.experiments.run(stage.addressing, queue=True)
results = dvc.experiments.run(run_all=True)

exp = first(results)
fs = scm.get_fs(exp)
assert not fs.exists("dir")
assert not fs.exists("file")
1 change: 1 addition & 0 deletions tests/func/experiments/test_set_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_hydra_sweep(
params=e,
reset=True,
targets=None,
copy_paths=None
)


Expand Down
1 change: 1 addition & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_experiments_run(dvc, scm, mocker):
"checkpoint_resume": None,
"reset": False,
"machine": None,
"copy_paths": [],
}
default_arguments.update(repro_arguments)

Expand Down

0 comments on commit ce8434a

Please sign in to comment.