Skip to content

Commit

Permalink
exp: refactor workspace repro into executor classes
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Dec 14, 2021
1 parent 9d6603d commit 2671727
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 91 deletions.
98 changes: 15 additions & 83 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
from functools import wraps
from typing import Dict, Iterable, List, Mapping, Optional
from typing import Dict, Iterable, List, Mapping, Optional, Type

from funcy import cached_property, first

Expand All @@ -13,7 +13,6 @@

from .base import (
EXEC_APPLY,
EXEC_BASELINE,
EXEC_BRANCH,
EXEC_CHECKPOINT,
EXEC_NAMESPACE,
Expand All @@ -31,7 +30,11 @@
BaseExecutor,
ExecutorInfo,
)
from .executor.manager import BaseExecutorManager, LocalExecutorManager
from .executor.manager import (
BaseExecutorManager,
TempDirExecutorManager,
WorkspaceExecutorManager,
)
from .utils import exp_refs_by_rev

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -424,9 +427,14 @@ def reproduce_one(
)
return [stash_rev]
if tmp_dir or queue:
results = self._reproduce_revs(revs=[stash_rev], keep_stash=False)
manager_cls: Type = TempDirExecutorManager
else:
results = self._workspace_repro()
manager_cls = WorkspaceExecutorManager
results = self._reproduce_revs(
revs=[stash_rev],
keep_stash=False,
manager_cls=manager_cls,
)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
Expand Down Expand Up @@ -580,6 +588,7 @@ def _reproduce_revs(
self,
revs: Optional[Iterable] = None,
keep_stash: Optional[bool] = True,
manager_cls: Type = TempDirExecutorManager,
**kwargs,
) -> Mapping[str, str]:
"""Reproduce the specified experiments.
Expand Down Expand Up @@ -616,7 +625,7 @@ def _reproduce_revs(
", ".join(rev[:7] for rev in to_run),
)

manager = LocalExecutorManager.from_stash_entries(
manager = manager_cls.from_stash_entries(
self.scm,
os.path.join(self.repo.tmp_dir, EXEC_TMP_DIR),
self.repo,
Expand Down Expand Up @@ -657,83 +666,6 @@ def _executors_repro(
"""
return manager.exec_queue(**kwargs)

@unlocked_repo
def _workspace_repro(self) -> Mapping[str, str]:
    """Run the most recently stashed experiment in the workspace.

    Returns:
        ``{exp_rev: exp_hash}`` when the run produced an experiment ref;
        an empty dict when the run produced no ref info or was stopped by
        a ``CheckpointKilledError``.

    Raises:
        DvcException: if the executor result has no ``exp_hash``, or if
            any other non-DVC exception escapes the run (wrapped).
    """
    from dvc.stage.monitor import CheckpointKilledError
    from dvc.utils.fs import makedirs

    # Only the first stash entry is run, and it must sit at stash index 0
    # because it is popped (not looked up by index) below.
    entry = first(self.stash_revs.values())
    assert entry.stash_index == 0

    # NOTE: the stash commit to be popped already contains all the current
    # workspace changes plus CLI modified --params changes.
    # `checkout --force` here will not lose any data (popping stash commit
    # will result in conflict between workspace params and stashed CLI
    # params, but we always want the stashed version).
    with self.scm.detach_head(entry.head_rev, force=True, client="dvc"):
        rev = self.stash.pop()
        self.scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
        # EXEC_BRANCH either points (symbolically) at the entry's branch
        # or is cleared so that no value from an earlier run is left over.
        if entry.branch:
            self.scm.set_ref(EXEC_BRANCH, entry.branch, symbolic=True)
        elif self.scm.get_ref(EXEC_BRANCH):
            self.scm.remove_ref(EXEC_BRANCH)
        try:
            # NOTE(review): orig_checkpoint is bound inside this try but
            # read in the finally block; if this get_ref call itself
            # raised, the finally block would hit an unbound name --
            # confirm whether that can happen.
            orig_checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)
            pid_dir = os.path.join(
                self.repo.tmp_dir,
                EXEC_TMP_DIR,
                EXEC_PID_DIR,
            )
            if not os.path.exists(pid_dir):
                makedirs(pid_dir)
            infofile = os.path.join(
                pid_dir, f"workspace{BaseExecutor.INFOFILE_EXT}"
            )
            info = ExecutorInfo(
                git_url=self.scm.root_dir,
                baseline_rev=entry.baseline_rev,
                location="workspace",
                root_dir=self.scm.root_dir,
                dvc_dir=self.dvc_dir,
                name=entry.name,
                wdir=os.getcwd(),
            )
            exec_result = BaseExecutor.reproduce(
                info=info,
                rev=rev,
                infofile=infofile,
                log_errors=False,
            )

            if not exec_result.exp_hash:
                raise DvcException(
                    f"Failed to reproduce experiment '{rev[:7]}'"
                )
            if not exec_result.ref_info:
                # repro succeeded but result matches baseline
                # (no experiment generated or applied)
                return {}
            exp_rev = self.scm.get_ref(str(exec_result.ref_info))
            self.scm.set_ref(EXEC_APPLY, exp_rev)
            return {exp_rev: exec_result.exp_hash}
        except CheckpointKilledError:
            # Checkpoint errors have already been logged
            return {}
        except DvcException:
            # DVC errors propagate unchanged; only foreign exceptions
            # are wrapped by the handler below.
            raise
        except Exception as exc:
            raise DvcException(
                f"Failed to reproduce experiment '{rev[:7]}'"
            ) from exc
        finally:
            # Always drop the temporary refs set above.
            self.scm.remove_ref(EXEC_BASELINE)
            if entry.branch:
                self.scm.remove_ref(EXEC_BRANCH)
            # If EXEC_CHECKPOINT changed during the run, record the new
            # value under EXEC_APPLY.
            checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)
            if checkpoint and checkpoint != orig_checkpoint:
                self.scm.set_ref(EXEC_APPLY, checkpoint)

def check_baseline(self, exp_rev):
baseline_sha = self.repo.scm.get_rev()
if exp_rev == baseline_sha:
Expand Down
49 changes: 48 additions & 1 deletion dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from contextlib import ExitStack
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

Expand All @@ -9,6 +10,8 @@
from dvc.utils.fs import remove

from ..base import (
EXEC_APPLY,
EXEC_BASELINE,
EXEC_BRANCH,
EXEC_CHECKPOINT,
EXEC_HEAD,
Expand Down Expand Up @@ -116,4 +119,48 @@ def from_stash_entry(


class WorkspaceExecutor(BaseLocalExecutor):
    """Local executor created with the repo's SCM root as its root dir.

    ``init_git`` enters a detached-HEAD context that is kept open on an
    ``ExitStack`` until ``cleanup`` closes it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keeps the detach_head() context entered in init_git() open so
        # that cleanup() can exit it later.
        self._detach_stack = ExitStack()
        # EXEC_CHECKPOINT as it was at construction time; cleanup()
        # compares against it to spot a checkpoint ref that changed.
        self._orig_checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)

    @classmethod
    def from_stash_entry(
        cls,
        repo: "Repo",
        stash_rev: str,
        entry: "ExpStashEntry",
        **kwargs,
    ):
        """Create an executor for ``entry`` rooted at ``repo.scm.root_dir``.

        Extra keyword arguments are accepted but not forwarded.
        """
        workspace_root = repo.scm.root_dir
        instance = cls._from_stash_entry(
            repo, stash_rev, entry, workspace_root
        )
        logger.debug("Init workspace executor in '%s'", workspace_root)
        return instance

    def init_git(self, scm: "Git", branch: Optional[str] = None):
        """Detach HEAD at EXEC_HEAD, merge EXEC_MERGE and set EXEC_BRANCH.

        The merge is squashed and not committed. When ``branch`` is given,
        EXEC_BRANCH becomes a symbolic ref to it; otherwise an existing
        EXEC_BRANCH ref is removed.
        """
        head_rev = self.scm.get_ref(EXEC_HEAD)
        detached = self.scm.detach_head(head_rev, force=True, client="dvc")
        # The context stays entered after this method returns; cleanup()
        # is what exits it.
        self._detach_stack.enter_context(detached)

        merge_source = self.scm.get_ref(EXEC_MERGE)
        self.scm.merge(merge_source, squash=True, commit=False)

        if branch:
            self.scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
            return
        # NOTE(review): this lookup goes through the `scm` argument while
        # every other call here uses self.scm -- confirm they are the same
        # object.
        if scm.get_ref(EXEC_BRANCH):
            self.scm.remove_ref(EXEC_BRANCH)

    def init_cache(self, dvc: "Repo", rev: str, run_cache: bool = True):
        """Do nothing; this executor performs no cache initialization."""

    def cleanup(self):
        """Remove the temporary refs, exit the detached-HEAD context and
        run the parent class cleanup.
        """
        with self._detach_stack:
            self.scm.remove_ref(EXEC_BASELINE)
            if self.scm.get_ref(EXEC_BRANCH):
                self.scm.remove_ref(EXEC_BRANCH)
            # If EXEC_CHECKPOINT differs from the value seen in __init__,
            # record the new value under EXEC_APPLY.
            latest = self.scm.get_ref(EXEC_CHECKPOINT)
            if latest and latest != self._orig_checkpoint:
                self.scm.set_ref(EXEC_APPLY, latest)
        super().cleanup()
97 changes: 90 additions & 7 deletions dvc/repo/experiments/executor/manager.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
import logging
import os
from abc import ABC
from collections import deque
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Deque, Dict, Optional, Tuple, Type

from dvc.proc.manager import ProcessManager

from ..base import (
EXEC_BASELINE,
EXEC_BRANCH,
EXEC_HEAD,
EXEC_MERGE,
CheckpointExistsError,
ExperimentExistsError,
ExpRefInfo,
ExpStashEntry,
)
from .base import EXEC_PID_DIR
from .base import EXEC_PID_DIR, BaseExecutor
from .local import TempDirExecutor, WorkspaceExecutor

if TYPE_CHECKING:
from scmrepo.git import Git

from dvc.repo import Repo

from .base import BaseExecutor

logger = logging.getLogger(__name__)


class BaseExecutorManager(ABC):
"""Manages executors for a collection of experiments to be run."""

EXECUTOR_CLS: Type = WorkspaceExecutor
EXECUTOR_CLS: Type = BaseExecutor

def __init__(
self,
Expand Down Expand Up @@ -124,7 +123,6 @@ def from_stash_entries(
def exec_queue(self, jobs: Optional[int] = 1):
"""Run dvc repro for queued executors in parallel."""
import signal
from collections import defaultdict
from concurrent.futures import (
CancelledError,
ProcessPoolExecutor,
Expand Down Expand Up @@ -230,5 +228,90 @@ def on_diverged(ref: str, checkpoint: bool):
return results


class LocalExecutorManager(BaseExecutorManager):
class TempDirExecutorManager(BaseExecutorManager):
    """Executor manager whose executor class is ``TempDirExecutor``."""

    EXECUTOR_CLS = TempDirExecutor


class WorkspaceExecutorManager(BaseExecutorManager):
    """Executor manager that runs exactly one ``WorkspaceExecutor``."""

    EXECUTOR_CLS = WorkspaceExecutor

    @classmethod
    def from_stash_entries(
        cls,
        scm: "Git",
        wdir: str,
        repo: "Repo",
        to_run: Dict[str, ExpStashEntry],
        **kwargs,
    ):
        """Build a manager with one queued executor for ``to_run``.

        ``to_run`` must hold exactly one stash entry. EXEC_HEAD,
        EXEC_MERGE and EXEC_BASELINE are set from that entry before the
        executor is created; EXEC_MERGE is removed again before
        returning, whether or not creation succeeded.
        """
        manager = cls(scm, wdir)
        try:
            assert len(to_run) == 1
            for stash_rev, entry in to_run.items():
                scm.set_ref(EXEC_HEAD, entry.head_rev)
                scm.set_ref(EXEC_MERGE, stash_rev)
                scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
                manager.enqueue(
                    stash_rev,
                    cls.EXECUTOR_CLS.from_stash_entry(
                        repo,
                        stash_rev,
                        entry,
                        **kwargs,
                    ),
                )
        finally:
            # Only EXEC_MERGE is dropped here; EXEC_HEAD and EXEC_BASELINE
            # are left in place by this method.
            scm.remove_ref(EXEC_MERGE)
        return manager

    def _collect_executor(self, executor, exec_result) -> Dict[str, str]:
        """Map the rev behind EXEC_BRANCH to ``exec_result.exp_hash``.

        Returns an empty dict when EXEC_BRANCH does not resolve.
        """
        exp_rev = self.scm.get_ref(EXEC_BRANCH)
        if not exp_rev:
            return {}
        logger.debug("Collected experiment '%s'.", exp_rev[:7])
        return {exp_rev: exec_result.exp_hash}

    def exec_queue(self, jobs: Optional[int] = 1):
        """Run the single queued WorkspaceExecutor.

        Workspace execution happens inside the main DVC process rather
        than in a multiprocessing context. ``jobs`` is accepted but not
        used by this implementation.

        Returns a mapping of the stash rev to ``{exp_rev: exp_hash}``, or
        an empty dict if the run was stopped by ``CheckpointKilledError``.
        Raises ``DvcException`` when the run yields no ``exp_hash`` or
        fails with any other non-DVC exception (wrapped).
        """
        from dvc.exceptions import DvcException
        from dvc.stage.monitor import CheckpointKilledError

        assert len(self._queue) == 1
        collected: Dict[str, Dict[str, str]] = defaultdict(dict)
        rev, executor = self._queue.popleft()
        infofile = os.path.join(
            self.pid_dir, f"{rev}{executor.INFOFILE_EXT}"
        )
        try:
            outcome = executor.reproduce(
                info=executor.info,
                rev=rev,
                infofile=infofile,
                log_level=logger.getEffectiveLevel(),
            )
            if not outcome.exp_hash:
                raise DvcException(
                    f"Failed to reproduce experiment '{rev[:7]}'"
                )
            if outcome.ref_info:
                collected[rev].update(
                    self._collect_executor(executor, outcome)
                )
        except CheckpointKilledError:
            # Checkpoint errors have already been logged
            return {}
        except DvcException:
            # DVC errors pass through untouched; anything else is wrapped
            # by the handler below.
            raise
        except Exception as exc:
            raise DvcException(
                f"Failed to reproduce experiment '{rev[:7]}'"
            ) from exc
        finally:
            # The executor is cleaned up on every exit path.
            executor.cleanup()
        return collected

0 comments on commit 2671727

Please sign in to comment.