Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plots diff: collect live plots for queued experiments #9432

Merged
merged 5 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,6 @@ def get_exact_name(self, revs: Iterable[str]) -> Dict[str, Optional[str]]:
result[rev] = name
return result

def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Dict]:
"""Return info for running experiments."""
result = {}
for queue in (
self.workspace_queue,
self.tempdir_queue,
self.celery_queue,
):
result.update(queue.get_running_exps(fetch_refs))
return result

def apply(self, *args, **kwargs):
from dvc.repo.experiments.apply import apply

Expand Down
47 changes: 47 additions & 0 deletions dvc/repo/experiments/brancher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Iterator, Tuple

from dvc.repo.experiments.exceptions import InvalidExpRevError
from dvc.scm import RevError

if TYPE_CHECKING:
from dvc.repo import Repo


@contextmanager
def switch_repo(
repo: "Repo",
rev: str,
) -> Iterator[Tuple["Repo", str]]:
"""Return a repo instance (brancher) switched to rev.

If rev is the name of a running experiment, the returned instance will be
the live repo wherever the experiment is running.

NOTE: This will not resolve git SHA's that only exist in queued exp workspaces
(it will only match queued exp names).
"""
try:
with repo.switch(rev):
yield repo, rev
return
except RevError as exc:
orig_exc = exc
exps = repo.experiments

if rev == exps.workspace_queue.get_running_exp():
yield repo, "workspace"
return

for queue in (exps.tempdir_queue, exps.celery_queue):
try:
active_repo = queue.active_repo(rev)
except InvalidExpRevError:
continue
stack = ExitStack()
stack.enter_context(active_repo)
stack.enter_context(active_repo.switch("workspace"))
with stack:
yield active_repo, rev
return
raise orig_exc
7 changes: 7 additions & 0 deletions dvc/repo/experiments/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,10 @@ class UnresolvedRunningExpNamesError(UnresolvedExpNamesError):

class ExpQueueEmptyError(DvcException):
pass


class ExpNotStartedError(DvcException):
def __init__(self, name: str):
super().__init__(
f"Queued experiment '{name}' exists but has not started running yet"
)
3 changes: 3 additions & 0 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,9 @@ def filter_pipeline(stages):
exp_ref: Optional["ExpRefInfo"] = None
repro_force: bool = False

if info.name:
ui.write(f"Reproducing experiment '{info.name}'")

with cls._repro_dvc(
info,
infofile,
Expand Down
31 changes: 23 additions & 8 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,14 +649,6 @@ def stash_failed(self, entry: QueueEntry) -> None:
message=f"commit: {msg}",
)

@abstractmethod
def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Dict]:
"""Get the execution info of the currently running experiments

Args:
fetch_ref (bool): fetch completed checkpoints or not.
"""

@abstractmethod
def collect_active_data(
self,
Expand Down Expand Up @@ -710,3 +702,26 @@ def collect_failed_data(
Returns:
Dict mapping baseline revision to list of queued experiments.
"""

def active_repo(self, name: str) -> "Repo":
"""Return a Repo for the specified active experiment if it exists."""
from dvc.repo import Repo
from dvc.repo.experiments.exceptions import (
ExpNotStartedError,
InvalidExpRevError,
)
from dvc.repo.experiments.executor.base import ExecutorInfo, TaskStatus

for entry in self.iter_active():
if entry.name != name:
continue
infofile = self.get_infofile_path(entry.stash_rev)
executor_info = ExecutorInfo.load_json(infofile)
if executor_info.status < TaskStatus.RUNNING:
raise ExpNotStartedError(name)
dvc_root = os.path.join(executor_info.root_dir, executor_info.dvc_dir)
try:
return Repo(dvc_root)
except (FileNotFoundError, DvcException) as exc:
raise InvalidExpRevError(name) from exc
raise InvalidExpRevError(name)
13 changes: 0 additions & 13 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,19 +489,6 @@ def remove(self, *args, **kwargs):

return celery_remove(self, *args, **kwargs)

def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Dict]:
"""Get the execution info of the currently running experiments

Args:
fetch_ref (bool): fetch completed checkpoints or not.
"""
result: Dict[str, Dict] = {}
for entry in self.iter_active():
result.update(
fetch_running_exp_from_temp_dir(self, entry.stash_rev, fetch_refs)
)
return result

def get_ref_and_entry_by_names(
self,
exp_names: Union[str, List[str]],
Expand Down
8 changes: 0 additions & 8 deletions dvc/repo/experiments/queue/tempdir.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ def collect_executor(
) -> Dict[str, str]:
return BaseStashQueue.collect_executor(exp, executor, exec_result)

def get_running_exps(self, fetch_refs: bool = True) -> Dict[str, Dict]:
result: Dict[str, Dict] = {}
for entry in self.iter_active():
result.update(
fetch_running_exp_from_temp_dir(self, entry.stash_rev, fetch_refs)
)
return result

def collect_active_data(
self,
baseline_revs: Optional[Collection[str]],
Expand Down
25 changes: 5 additions & 20 deletions dvc/repo/experiments/queue/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,33 +187,18 @@ def logs(
):
raise NotImplementedError

def get_running_exps(
self,
fetch_refs: bool = True, # noqa: ARG002
) -> Dict[str, Dict]:
def get_running_exp(self) -> Optional[str]:
"""Return the name of the exp running in workspace (if it exists)."""
assert self._EXEC_NAME
result: Dict[str, Dict] = {}
if self._active_pid is None:
return result
return None

infofile = self.get_infofile_path(self._EXEC_NAME)

try:
info = ExecutorInfo.from_dict(load_json(infofile))
except OSError:
return result

if info.status < TaskStatus.FAILED:
# If we are appending to a checkpoint branch in a workspace
# run, show the latest checkpoint as running.
if info.status == TaskStatus.SUCCESS:
return result
last_rev = self.scm.get_ref(EXEC_BRANCH)
if last_rev:
result[last_rev] = info.asdict()
else:
result[self._EXEC_NAME] = info.asdict()
return result
return None
return info.name

def collect_active_data(
self,
Expand Down
59 changes: 33 additions & 26 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,51 +121,58 @@ def collect(

}
"""
from dvc.repo.experiments.brancher import switch_repo
from dvc.utils.collections import ensure_list

targets = ensure_list(targets)
targets = [self.repo.dvcfs.from_os_path(target) for target in targets]

for rev in self.repo.brancher(revs=revs):
# .brancher() adds unwanted workspace
if revs is not None and rev not in revs:
continue
rev = rev or "workspace"

res: Dict = {}
definitions = _collect_definitions(
self.repo,
targets=targets,
revision=rev,
onerror=onerror,
props=props,
)
if definitions:
res[rev] = {"definitions": definitions}

data_targets = _get_data_targets(definitions)

res[rev]["sources"] = self._collect_data_sources(
targets=data_targets,
recursive=recursive,
props=props,
if revs is None:
revs = ["workspace"]
else:
revs = list(revs)
if "workspace" in revs:
# reorder revs to match repo.brancher ordering
revs.remove("workspace")
revs = ["workspace"] + revs
for rev in revs:
with switch_repo(self.repo, rev) as (repo, _):
res: Dict = {}
definitions = _collect_definitions(
repo,
targets=targets,
revision=rev,
onerror=onerror,
props=props,
)
yield res
if definitions:
res[rev] = {"definitions": definitions}

data_targets = _get_data_targets(definitions)

res[rev]["sources"] = self._collect_data_sources(
repo,
targets=data_targets,
recursive=recursive,
props=props,
onerror=onerror,
)
yield res

@error_handler
def _collect_data_sources(
self,
repo: "Repo",
targets: Optional[List[str]] = None,
recursive: bool = False,
props: Optional[Dict] = None,
onerror: Optional[Callable] = None,
):
fs = self.repo.dvcfs
fs = repo.dvcfs

props = props or {}

plots = _collect_plots(self.repo, targets, recursive)
plots = _collect_plots(repo, targets, recursive)
res: Dict[str, Any] = {}
for fs_path, rev_props in plots.items():
joined_props = {**rev_props, **props}
Expand Down