Skip to content

Commit

Permalink
repro: support glob/foreach-group to run at once through CLI (#4976)
Browse files Browse the repository at this point in the history
* repro: support regex/foreach-group to run at once

Fixes #4912
Fixes #4886
Fixes #4958

* Use `tree.isdir` rather than `os.path.isdir`

* Use glob rather than regex

* Update dvc/command/repro.py

* s/regex/glob (rename the remaining "regex" references to "glob")

* disable glob on `collect_granular`

There's no need for glob matching in `collect_granular`

* add tests for `collect` and `collect_granular`
  • Loading branch information
skshetry authored Nov 27, 2020
1 parent 011bd18 commit f4f0554
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 101 deletions.
7 changes: 7 additions & 0 deletions dvc/command/repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _repro_kwargs(self):
"recursive": self.args.recursive,
"force_downstream": self.args.force_downstream,
"pull": self.args.pull,
"glob": self.args.glob,
}


Expand Down Expand Up @@ -175,6 +176,12 @@ def add_arguments(repro_parser):
"from the run-cache."
),
)
repro_parser.add_argument(
"--glob",
action="store_true",
default=False,
help="Allows targets containing shell-style wildcards.",
)


def add_parser(subparsers, parent_parser):
Expand Down
132 changes: 76 additions & 56 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
import os
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from functools import wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

from funcy import cached_property, cat
from git import InvalidGitRepositoryError

from dvc.config import Config
from dvc.dvcfile import PIPELINE_FILE, Dvcfile, is_valid_filename
from dvc.dvcfile import PIPELINE_FILE, is_valid_filename
from dvc.exceptions import FileMissingError
from dvc.exceptions import IsADirectoryError as DvcIsADirectoryError
from dvc.exceptions import (
Expand All @@ -17,6 +17,7 @@
OutputNotFoundError,
)
from dvc.path_info import PathInfo
from dvc.repo.stage import StageLoad
from dvc.scm import Base
from dvc.scm.base import SCMError
from dvc.tree.repo import RepoTree
Expand All @@ -28,6 +29,9 @@
from .trie import build_outs_trie

if TYPE_CHECKING:
from networkx import DiGraph

from dvc.stage import Stage
from dvc.tree.base import BaseTree


Expand Down Expand Up @@ -165,6 +169,7 @@ def __init__(

self.cache = Cache(self)
self.cloud = DataCloud(self)
self.stage = StageLoad(self)

if scm or not self.dvc_dir:
self.lock = LockNoop()
Expand Down Expand Up @@ -270,25 +275,6 @@ def _ignore(self):

self.scm.ignore_list(flist)

def get_stage(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

dvcfile = Dvcfile(self, path)
return dvcfile.stages[name]

def get_stages(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

if name:
return [self.get_stage(path, name)]

dvcfile = Dvcfile(self, path)
return list(dvcfile.stages.values())

def check_modified_graph(self, new_stages):
"""Generate graph including the new stage to check for errors"""
# Building graph might be costly for the ones with many DVC-files,
Expand All @@ -306,79 +292,105 @@ def check_modified_graph(self, new_stages):
if not getattr(self, "_skip_graph_checks", False):
build_graph(self.stages + new_stages)

def _collect_inside(self, path, graph):
@staticmethod
def _collect_inside(path: str, graph: "DiGraph"):
import networkx as nx

stages = nx.dfs_postorder_nodes(graph)
return [stage for stage in stages if path_isin(stage.path, path)]

def collect(
self, target=None, with_deps=False, recursive=False, graph=None
self,
target: str = None,
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
glob: bool = False,
):
if not target:
return list(graph) if graph else self.stages

if recursive and os.path.isdir(target):
if recursive and self.tree.isdir(target):
return self._collect_inside(
os.path.abspath(target), graph or self.graph
)

path, name = parse_target(target)
stages = self.get_stages(path, name)
stages = self.stage.from_target(
target, accept_group=accept_group, glob=glob
)
if not with_deps:
return stages

return self._collect_stages_with_deps(stages, graph=graph)

# Merge the pipelines of several stages into one collection.
# For each stage in `stages`, `_collect_pipeline(stage, graph=graph)` is
# called and everything it yields is added to a single set, so a stage that
# is reached from more than one input appears only once. `graph` is passed
# through unchanged. Note that the result is a set, not a list.
def _collect_stages_with_deps(
self, stages: List["Stage"], graph: "DiGraph" = None
):
res = set()
for stage in stages:
res.update(self._collect_pipeline(stage, graph=graph))
return res

def _collect_pipeline(self, stage, graph=None):
def _collect_pipeline(self, stage: "Stage", graph: "DiGraph" = None):
import networkx as nx

pipeline = get_pipeline(get_pipelines(graph or self.graph), stage)
return nx.dfs_postorder_nodes(pipeline, stage)

def _collect_from_default_dvcfile(self, target):
dvcfile = Dvcfile(self, PIPELINE_FILE)
if dvcfile.exists():
return dvcfile.stages.get(target)

def collect_granular(
self, target=None, with_deps=False, recursive=False, graph=None
def _collect_specific_target(
self, target: str, with_deps: bool, recursive: bool, accept_group: bool
):
"""
Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- output file
"""
if not target:
return [(stage, None) for stage in self.stages]

# Optimization: do not collect the graph for a specific target
file, name = parse_target(target)
stages = []

# Optimization: do not collect the graph for a specific target
if not file:
# parsing is ambiguous when it does not have a colon
# or if it's not a dvcfile, as it can be a stage name
# in `dvc.yaml` or, an output in a stage.
logger.debug(
"Checking if stage '%s' is in '%s'", target, PIPELINE_FILE
)
if not (recursive and os.path.isdir(target)):
stage = self._collect_from_default_dvcfile(target)
if stage:
stages = (
self._collect_pipeline(stage) if with_deps else [stage]
if not (
recursive and self.tree.isdir(target)
) and self.tree.exists(PIPELINE_FILE):
with suppress(StageNotFound):
stages = self.stage.load_all(
PIPELINE_FILE, target, accept_group=accept_group
)
if with_deps:
stages = self._collect_stages_with_deps(stages)

elif not with_deps and is_valid_filename(file):
stages = self.get_stages(file, name)
stages = self.stage.load_all(
file, name, accept_group=accept_group,
)

return stages, file, name

def collect_granular(
self,
target: str = None,
with_deps: bool = False,
recursive: bool = False,
graph: "DiGraph" = None,
accept_group: bool = False,
):
"""
Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- output file
"""
if not target:
return [(stage, None) for stage in self.stages]

stages, file, _ = self._collect_specific_target(
target, with_deps, recursive, accept_group
)
if not stages:
if not (recursive and os.path.isdir(target)):
if not (recursive and self.tree.isdir(target)):
try:
(out,) = self.find_outs_by_path(target, strict=False)
filter_info = PathInfo(os.path.abspath(target))
Expand All @@ -387,7 +399,13 @@ def collect_granular(
pass

try:
stages = self.collect(target, with_deps, recursive, graph)
stages = self.collect(
target,
with_deps,
recursive,
graph,
accept_group=accept_group,
)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
# and throw error that dvc.yaml does not exist, whereas it
Expand Down Expand Up @@ -498,7 +516,9 @@ def _collect_stages(self):

for root, dirs, files in self.tree.walk(self.root_dir):
for file_name in filter(is_valid_filename, files):
new_stages = self.get_stages(os.path.join(root, file_name))
new_stages = self.stage.load_file(
os.path.join(root, file_name)
)
stages.extend(new_stages)
outs.update(
out.fspath
Expand Down
12 changes: 7 additions & 5 deletions dvc/repo/freeze.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import typing

from . import locked

if typing.TYPE_CHECKING:
from . import Repo

@locked
def _set(repo, target, frozen):
from dvc.utils import parse_target

path, name = parse_target(target)
stage = repo.get_stage(path, name)
@locked
def _set(repo: "Repo", target, frozen):
stage = repo.stage.get_target(target)
stage.frozen = frozen
stage.dvcfile.dump(stage, update_lock=False)

Expand Down
10 changes: 6 additions & 4 deletions dvc/repo/remove.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
import typing

from ..utils import parse_target
from . import locked

if typing.TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)


@locked
def remove(self, target, outs=False):
path, name = parse_target(target)
stages = self.get_stages(path, name)
def remove(self: "Repo", target: str, outs: bool = False):
stages = self.stage.from_target(target)

for stage in stages:
stage.remove(remove_outs=outs, force=outs)
Expand Down
21 changes: 15 additions & 6 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import typing
from functools import partial

from dvc.exceptions import InvalidArgumentError, ReproductionError
Expand All @@ -8,6 +9,9 @@
from . import locked
from .graph import get_pipeline, get_pipelines

if typing.TYPE_CHECKING:
from . import Repo

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -75,15 +79,15 @@ def _get_active_graph(G):
@locked
@scm_context
def reproduce(
self,
self: "Repo",
target=None,
recursive=False,
pipeline=False,
all_pipelines=False,
**kwargs,
):
from dvc.utils import parse_target

glob = kwargs.pop("glob", False)
accept_group = not glob
assert target is None or isinstance(target, str)
if not target and not all_pipelines:
raise InvalidArgumentError(
Expand All @@ -97,12 +101,11 @@ def reproduce(
active_graph = _get_active_graph(self.graph)
active_pipelines = get_pipelines(active_graph)

path, name = parse_target(target)
if pipeline or all_pipelines:
if all_pipelines:
pipelines = active_pipelines
else:
stage = self.get_stage(path, name)
stage = self.stage.get_target(target)
pipelines = [get_pipeline(active_pipelines, stage)]

targets = []
Expand All @@ -111,7 +114,13 @@ def reproduce(
if pipeline.in_degree(stage) == 0:
targets.append(stage)
else:
targets = self.collect(target, recursive=recursive, graph=active_graph)
targets = self.collect(
target,
recursive=recursive,
graph=active_graph,
accept_group=accept_group,
glob=glob,
)

return _reproduce_stages(active_graph, targets, **kwargs)

Expand Down
Loading

0 comments on commit f4f0554

Please sign in to comment.