diff --git a/dvc/command/dag.py b/dvc/command/dag.py index 07c519fbd7..8867bf1dd0 100644 --- a/dvc/command/dag.py +++ b/dvc/command/dag.py @@ -36,18 +36,20 @@ def _collect_targets(repo, target, outs): return [stage.addressing for stage, _ in pairs] targets = [] + + outs_trie = repo.index.outs_trie for stage, info in pairs: if not info: targets.extend([str(out) for out in stage.outs]) continue - for out in repo.outs_trie.itervalues(prefix=info.parts): # noqa: B301 + for out in outs_trie.itervalues(prefix=info.parts): # noqa: B301 targets.extend(str(out)) return targets -def _transform(repo, outs): +def _transform(index, outs): import networkx as nx from dvc.stage import Stage @@ -55,7 +57,7 @@ def _transform(repo, outs): def _relabel(node) -> str: return node.addressing if isinstance(node, Stage) else str(node) - G = repo.outs_graph if outs else repo.graph + G = index.outs_graph if outs else index.graph return nx.relabel_nodes(G, _relabel, copy=True) @@ -85,7 +87,7 @@ def _filter(G, targets, full): def _build(repo, target=None, full=False, outs=False): targets = _collect_targets(repo, target, outs) - G = _transform(repo, outs) + G = _transform(repo.index, outs) return _filter(G, targets, full) diff --git a/dvc/command/stage.py b/dvc/command/stage.py index b9491b2d67..aab3c1c652 100644 --- a/dvc/command/stage.py +++ b/dvc/command/stage.py @@ -66,7 +66,7 @@ def prepare_stages_data( class CmdStageList(CmdBase): def _get_stages(self) -> Iterable["Stage"]: if self.args.all: - stages: List["Stage"] = self.repo.stages # type: ignore + stages: List["Stage"] = self.repo.index.stages # type: ignore logger.trace( # type: ignore[attr-defined] "%d no. of stages found", len(stages) ) diff --git a/dvc/command/status.py b/dvc/command/status.py index c673956020..5252b0c5e8 100644 --- a/dvc/command/status.py +++ b/dvc/command/status.py @@ -80,7 +80,7 @@ def run(self): return 0 # additional hints for the user - if not self.repo.stages: + if not self.repo.index.stages: ui.write(self.EMPTY_PROJECT_MSG) elif self.args.cloud or self.args.remote: remote = self.args.remote or self.repo.config["core"].get( diff --git a/dvc/objects/__init__.py b/dvc/objects/__init__.py index ad9721a469..159b8fb6ac 100644 --- a/dvc/objects/__init__.py +++ b/dvc/objects/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator, Union from .tree import Tree @@ -24,3 +24,11 @@ def load(odb: "ObjectDB", hash_info: "HashInfo") -> "HashFile": if hash_info.isdir: return Tree.load(odb, hash_info) return odb.get(hash_info) + + +def iterobjs( + obj: Union["Tree", "HashFile"] +) -> Iterator[Union["Tree", "HashFile"]]: + if isinstance(obj, Tree): + yield from (entry_obj for _, entry_obj in obj) + yield obj diff --git a/dvc/output.py b/dvc/output.py index d3d3bf343e..910ebfafea 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -1004,6 +1004,18 @@ def merge(self, ancestor, other): def fspath(self): return self.path_info.fspath + @property + def is_decorated(self) -> bool: + return self.is_metric or self.is_plot + + @property + def is_metric(self) -> bool: + return bool(self.metric) or bool(self.live) + + @property + def is_plot(self) -> bool: + return bool(self.plot) + ARTIFACT_SCHEMA = { **CHECKSUMS_SCHEMA, diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 00405c5e23..e853be2190 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -3,9 +3,9 @@ from collections import defaultdict from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Set -from funcy import cached_property, cat +from funcy import cached_property from dvc.exceptions import FileMissingError from dvc.exceptions import IsADirectoryError as DvcIsADirectoryError @@ -14,11 +14,9 @@ from dvc.path_info import PathInfo from dvc.utils.fs import path_isin -from .graph import build_graph, build_outs_graph, get_pipelines -from .trie import build_outs_trie - if TYPE_CHECKING: from dvc.fs.base import BaseFileSystem + from dvc.objects.file import HashFile from dvc.scm import Base @@ -57,7 +55,6 @@ class Repo: DVC_DIR = ".dvc" from dvc.repo.add import add - from dvc.repo.brancher import brancher from dvc.repo.checkout import checkout from dvc.repo.commit import commit from dvc.repo.destroy import destroy @@ -207,6 +204,12 @@ def __init__( def __str__(self): return self.url or self.root_dir + @cached_property + def index(self): + from dvc.repo.index import Index + + return Index(self) + @staticmethod def open(url, *args, **kwargs): if url is None: @@ -323,23 +326,10 @@ def _ignore(self): self.scm.ignore_list(flist) - def check_modified_graph(self, new_stages, old_stages=None): - """Generate graph including the new stage to check for errors""" - # Building graph might be costly for the ones with many DVC-files, - # so we provide this undocumented hack to skip it. See [1] for - # more details. The hack can be used as: - # - # repo = Repo(...) - # repo._skip_graph_checks = True - # repo.add(...) - # - # A user should care about not duplicating outs and not adding cycles, - # otherwise DVC might have an undefined behaviour. - # - # [1] https://github.com/iterative/dvc/issues/2671 - if not getattr(self, "_skip_graph_checks", False): - existing_stages = self.stages if old_stages is None else old_stages - build_graph(existing_stages + new_stages) + def brancher(self, *args, **kwargs): + from dvc.repo.brancher import brancher + + return brancher(self, *args, **kwargs) def used_objs( self, @@ -373,16 +363,14 @@ def used_objs( """ used = defaultdict(set) - def _add_suffix(objs, suffix): - from dvc.objects.tree import Tree + def _add_suffix(objs: Set["HashFile"], suffix: str) -> None: + from itertools import chain + + from dvc.objects import iterobjs - for obj in objs: + for obj in chain.from_iterable(map(iterobjs, objs)): if obj.name is not None: obj.name += suffix - if isinstance(obj, Tree): - for _, entry_obj in obj: - if entry_obj.name is not None: - entry_obj.name += suffix for branch in self.brancher( revs=revs, @@ -391,25 +379,17 @@ def _add_suffix(objs, suffix): all_commits=all_commits, all_experiments=all_experiments, ): - targets = targets or [None] - - pairs = cat( - self.stage.collect_granular( - target, recursive=recursive, with_deps=with_deps - ) - for target in targets - ) - - for stage, filter_info in pairs: - for odb, objs in stage.get_used_objs( - remote=remote, - force=force, - jobs=jobs, - filter_info=filter_info, - ).items(): - if branch: - _add_suffix(objs, f" ({branch})") - used[odb].update(objs) + for odb, objs in self.index.used_objs( + targets, + remote=remote, + force=force, + jobs=jobs, + recursive=recursive, + with_deps=with_deps, + ).items(): + if branch: + _add_suffix(objs, f" ({branch})") + used[odb].update(objs) if used_run_cache: for odb, objs in self.stage_cache.get_used_objs( @@ -419,39 +399,13 @@ def _add_suffix(objs, suffix): return used - @cached_property - def outs_trie(self): - return build_outs_trie(self.stages) - - @cached_property - def graph(self): - return build_graph(self.stages, self.outs_trie) - - @cached_property - def outs_graph(self): - return build_outs_graph(self.graph, self.outs_trie) - - @cached_property - def pipelines(self): - return get_pipelines(self.graph) - - @cached_property - def stages(self): - """ - Walks down the root directory looking for Dvcfiles, - skipping the directories that are related with - any SCM (e.g. `.git`), DVC itself (`.dvc`), or directories - tracked by DVC (e.g. `dvc add data` would skip `data/`) - - NOTE: For large repos, this could be an expensive - operation. Consider using some memoization. - """ - error_handler = self.stage_collection_error_handler - return self.stage.collect_repo(onerror=error_handler) + @property + def stages(self): # obsolete, only for backward-compatibility + return self.index.stages def find_outs_by_path(self, path, outs=None, recursive=False, strict=True): # using `outs_graph` to ensure graph checks are run - outs = outs or self.outs_graph + outs = outs or self.index.outs_graph abs_path = os.path.abspath(path) path_info = PathInfo(abs_path) @@ -512,11 +466,7 @@ def close(self): def _reset(self): self.state.close() self.scm._reset() # pylint: disable=protected-access - self.__dict__.pop("outs_trie", None) - self.__dict__.pop("outs_graph", None) - self.__dict__.pop("graph", None) - self.__dict__.pop("stages", None) - self.__dict__.pop("pipelines", None) + self.__dict__.pop("index", None) self.__dict__.pop("dvcignore", None) def __enter__(self): diff --git a/dvc/repo/add.py b/dvc/repo/add.py index 18a05282fa..7318cb320e 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -169,14 +169,13 @@ def add( # noqa: C901 desc = "Collecting targets" stages_it = create_stages(repo, add_targets, fname, transfer, **kwargs) stages = list(ui.progress(stages_it, desc=desc, unit="file")) - msg = "Collecting stages from the workspace" with translate_graph_error(stages), ui.status(msg) as status: # remove existing stages that are to-be replaced with these # new stages for the graph checks. - old_stages = set(repo.stages) - set(stages) + new_index = repo.index.update(stages) status.update("Checking graph") - repo.check_modified_graph(stages, list(old_stages)) + new_index.check_graph() odb = None if to_remote: diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 356a83957f..0418126101 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -28,13 +28,7 @@ def _fspath_dir(path): def _remove_unused_links(repo): - used = [ - out.fspath - for stage in repo.stages - for out in stage.outs - if out.scheme == "local" - ] - + used = [out.fspath for out in repo.index.outs if out.scheme == "local"] unused = repo.state.get_unused_links(used, repo.fs) ret = [_fspath_dir(u) for u in unused] repo.state.remove_links(unused, repo.fs) diff --git a/dvc/repo/collect.py b/dvc/repo/collect.py index acc4f0497a..dac019972f 100644 --- a/dvc/repo/collect.py +++ b/dvc/repo/collect.py @@ -20,12 +20,9 @@ def _collect_outs( repo: "Repo", output_filter: FilterFn = None, deps: bool = False ) -> Outputs: - outs = [ - out - for stage in repo.graph # using `graph` to ensure graph checks run - for out in (stage.deps if deps else stage.outs) - ] - return list(filter(output_filter, outs)) if output_filter else outs + index = repo.index + index.check_graph() # ensure graph is correct + return list(filter(output_filter, index.deps if deps else index.outs)) def _collect_paths( diff --git a/dvc/repo/destroy.py b/dvc/repo/destroy.py index ee5bbdc6b0..55da5f54fc 100644 --- a/dvc/repo/destroy.py +++ b/dvc/repo/destroy.py @@ -5,7 +5,7 @@ @locked def _destroy_stages(repo): - for stage in repo.stages: + for stage in repo.index.stages: stage.unprotect_outs() stage.dvcfile.remove(force=True) diff --git a/dvc/repo/diff.py b/dvc/repo/diff.py index fbbd32aa33..ccc31e9245 100644 --- a/dvc/repo/diff.py +++ b/dvc/repo/diff.py @@ -142,21 +142,20 @@ def _to_checksum(output): )[1].hash_info.value return output.hash_info.value - for stage in repo.stages: - for output in stage.outs: - if _exists(output): - yield_output = targets is None or any( - output.path_info.isin_or_eq(target) for target in targets - ) - - if yield_output: - yield _to_path(output), _to_checksum(output) - - if output.is_dir_checksum and ( - yield_output - or any(target.isin(output.path_info) for target in targets) - ): - yield from _dir_output_paths(repo_fs, output, targets) + for output in repo.index.outs: + if _exists(output): + yield_output = targets is None or any( + output.path_info.isin_or_eq(target) for target in targets + ) + + if yield_output: + yield _to_path(output), _to_checksum(output) + + if output.is_dir_checksum and ( + yield_output + or any(target.isin(output.path_info) for target in targets) + ): + yield from _dir_output_paths(repo_fs, output, targets) def _dir_output_paths(repo_fs, output, targets=None): diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py index f372a3bf65..aa8ec83bae 100644 --- a/dvc/repo/imp_url.py +++ b/dvc/repo/imp_url.py @@ -67,7 +67,8 @@ def imp_url( dvcfile.remove() try: - self.check_modified_graph([stage]) + new_index = self.index.add(stage) + new_index.check_graph() except OutputDuplicationError as exc: raise OutputDuplicationError(exc.output, set(exc.stages) - {stage}) diff --git a/dvc/repo/index.py b/dvc/repo/index.py new file mode 100644 index 0000000000..ca55ecd329 --- /dev/null +++ b/dvc/repo/index.py @@ -0,0 +1,285 @@ +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, +) + +from funcy import cached_property, memoize, nullcontext + +from dvc.utils import dict_md5 + +if TYPE_CHECKING: + from networkx import DiGraph + from pygtrie import Trie + + from dvc.dependency import Dependency, ParamsDependency + from dvc.fs.base import BaseFileSystem + from dvc.objects import HashInfo, ObjectDB + from dvc.output import Output + from dvc.repo.stage import StageLoad + from dvc.stage import Stage + from dvc.types import StrPath, TargetType + + +ObjectContainer = Dict[Optional["ObjectDB"], Set["HashInfo"]] + + +class Index: + def __init__( + self, + repo: "Repo", # pylint: disable=redefined-outer-name + fs: "BaseFileSystem" = None, + stages: List["Stage"] = None, + ) -> None: + """Index is an immutable collection of stages. + + Generally, Index is a complete collection of stages at a point in time. + With "a point in time", it means it is collected from the user's + workspace or a git revision. + And, since Index is immutable, the collection is frozen in time. + + Index provides multiple ways to view this collection: + + stages - provides direct access to this collection + outputs - provides direct access to the outputs + objects - provides direct access to the objects + graph - + ... and many more. + + Index also provides ways to slice and dice this collection. + Some `views` might not make sense when sliced (eg: pipelines/graph). + """ + + self.repo: "Repo" = repo + self.fs: "BaseFileSystem" = fs or repo.fs + self.stage_collector: "StageLoad" = repo.stage + if stages is not None: + self.stages: List["Stage"] = stages + + @cached_property + def stages(self) -> List["Stage"]: # pylint: disable=method-hidden + # note that ideally we should be keeping this in a set as it is unique, + # hashable and has no concept of orderliness on its own. But we depend + # on this to be somewhat ordered for status/metrics/plots, etc. + onerror = self.repo.stage_collection_error_handler + return self.stage_collector.collect_repo(onerror=onerror) + + def __repr__(self) -> str: + from dvc.fs.local import LocalFileSystem + + rev = "workspace" + if not isinstance(self.fs, LocalFileSystem): + rev = self.repo.get_rev()[:7] + return f"Index({self.repo}, fs@{rev})" + + def __len__(self) -> int: + return len(self.stages) + + def __contains__(self, stage: "Stage") -> bool: + # as we are keeping stages inside a list, it might be slower. + return stage in self.stages + + def __iter__(self) -> Iterator["Stage"]: + yield from self.stages + + def filter(self, filter_fn: Callable[["Stage"], bool]) -> "Index": + stages_it = filter(filter_fn, self) + return Index(self.repo, self.fs, stages=list(stages_it)) + + def slice(self, path: "StrPath") -> "Index": + from dvc.utils import relpath + from dvc.utils.fs import path_isin + + target_path = relpath(path, self.repo.root_dir) + + def is_stage_inside_path(stage: "Stage") -> bool: + return path_isin(stage.path_in_repo, target_path) + + return self.filter(is_stage_inside_path) + + @property + def outs(self) -> Iterator["Output"]: + for stage in self: + yield from stage.outs + + @property + def decorated_outputs(self) -> Iterator["Output"]: + for output in self.outs: + if output.is_decorated: + yield output + + @property + def metrics(self) -> Iterator["Output"]: + for output in self.outs: + if output.is_metric: + yield output + + @property + def plots(self) -> Iterator["Output"]: + for output in self.outs: + if output.is_plot: + yield output + + @property + def deps(self) -> Iterator["Dependency"]: + for stage in self: + yield from stage.deps + + @property + def params(self) -> Iterator["ParamsDependency"]: + from dvc.dependency import ParamsDependency + + for dep in self.deps: + if isinstance(dep, ParamsDependency): + yield dep + + @cached_property + def outs_trie(self) -> "Trie": + from dvc.repo.trie import build_outs_trie + + return build_outs_trie(self.stages) + + @property + def graph(self) -> "DiGraph": + return self.build_graph() + + @cached_property + def outs_graph(self) -> "DiGraph": + from dvc.repo.graph import build_outs_graph + + return build_outs_graph(self.graph, self.outs_trie) + + def used_objs( + self, + targets: "TargetType" = None, + with_deps: bool = False, + remote: str = None, + force: bool = False, + recursive: bool = False, + jobs: int = None, + ) -> "ObjectContainer": + from collections import defaultdict + from itertools import chain + + from dvc.utils.collections import ensure_list + + used: "ObjectContainer" = defaultdict(set) + collect_targets: Sequence[Optional[str]] = (None,) + if targets: + collect_targets = ensure_list(targets) + + pairs = chain.from_iterable( + self.stage_collector.collect_granular( + target, recursive=recursive, with_deps=with_deps + ) + for target in collect_targets + ) + + for stage, filter_info in pairs: + for odb, objs in stage.get_used_objs( + remote=remote, + force=force, + jobs=jobs, + filter_info=filter_info, + ).items(): + used[odb].update(objs) + return used + + # Following methods help us treat the collection as a set-like structure + # and provides faux-immutability. + # These methods do not preserve stages order. + + def update(self, stages: Iterable["Stage"]) -> "Index": + new_stages = set(stages) + # we remove existing stages with same hashes at first + # and then re-add the new ones later. + stages_set = (set(self.stages) - new_stages) | new_stages + return Index(self.repo, self.fs, stages=list(stages_set)) + + def add(self, stage: "Stage") -> "Index": + return self.update([stage]) + + def remove( + self, stage: "Stage", ignore_not_existing: bool = False + ) -> "Index": + stages = self._discard_stage( + stage, ignore_not_existing=ignore_not_existing + ) + return Index(self.repo, self.fs, stages=stages) + + def discard(self, stage: "Stage") -> "Index": + return self.remove(stage, ignore_not_existing=True) + + def difference(self, stages: Iterable["Stage"]) -> "Index": + # this does not preserve the order + stages_set = set(self.stages) - set(stages) + return Index(self.repo, self.fs, stages=list(stages_set)) + + def _discard_stage( + self, stage: "Stage", ignore_not_existing: bool = False + ) -> List["Stage"]: + stages = self.stages[:] + ctx = suppress(ValueError) if ignore_not_existing else nullcontext() + with ctx: + stages.remove(stage) + return stages + + @memoize + def build_graph(self) -> "DiGraph": + from dvc.repo.graph import build_graph + + return build_graph(self.stages, self.outs_trie) + + def check_graph(self) -> None: + if not getattr(self.repo, "_skip_graph_checks", False): + self.build_graph() + + def dumpd(self) -> Dict[str, Dict]: + def dump(stage: "Stage"): + key = stage.path_in_repo + try: + key += ":" + stage.name # type: ignore[attr-defined] + except AttributeError: + pass + return key, stage.dumpd() + + return dict(dump(stage) for stage in self) + + @cached_property + def identifier(self) -> str: + """Unique identifier for the index. + + We can use this to optimize and skip opening some indices + eg: on push/pull/fetch/gc --all-commits. + + Currently, it is unique to the platform (windows vs posix). + """ + return dict_md5(self.dumpd()) + + +if __name__ == "__main__": + from funcy import log_durations + + from dvc.repo import Repo + + repo = Repo() + index = Index(repo, repo.fs) + print(index) + with log_durations(print, "collecting stages"): + # pylint: disable=pointless-statement + print("no of stages", len(index.stages)) + with log_durations(print, "building graph"): + index.build_graph() + with log_durations(print, "calculating hash"): + print(index.identifier) + with log_durations(print, "updating"): + index2 = index.update(index.stages) + with log_durations(print, "calculating hash"): + print(index2.identifier) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 6630ddd7e8..4ce1317d1f 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -83,7 +83,7 @@ def _read_params( def _collect_vars(repo, params) -> Dict: vars_params: Dict[str, Dict] = defaultdict(dict) - for stage in repo.stages: + for stage in repo.index.stages: if isinstance(stage, PipelineStage) and stage.tracked_vars: for file, vars_ in stage.tracked_vars.items(): # `params` file are shown regardless of `tracked` or not diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 9cbef4742a..fbccff4f1f 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -108,7 +108,7 @@ def reproduce( stages = set() if pipeline or all_pipelines: - pipelines = get_pipelines(self.graph) + pipelines = get_pipelines(self.index.graph) if all_pipelines: used_pipelines = pipelines else: @@ -132,7 +132,7 @@ def reproduce( ) ) - return _reproduce_stages(self.graph, list(stages), **kwargs) + return _reproduce_stages(self.index.graph, list(stages), **kwargs) def _reproduce_stages( diff --git a/dvc/repo/run.py b/dvc/repo/run.py index c44bf2558c..3f149d52d1 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -32,5 +32,8 @@ def run( else: stage.run(no_commit=no_commit, run_cache=run_cache) + new_index = self.index.add(stage) + new_index.check_graph() + stage.dump(update_lock=not no_exec) return stage diff --git a/dvc/repo/stage.py b/dvc/repo/stage.py index b8f5a02342..a3d87e6470 100644 --- a/dvc/repo/stage.py +++ b/dvc/repo/stage.py @@ -110,8 +110,9 @@ def wrapper(loader: "StageLoad", *args, **kwargs): class StageLoad: - def __init__(self, repo: "Repo") -> None: + def __init__(self, repo: "Repo", fs=None) -> None: self.repo: "Repo" = repo + self._fs = fs @locked def add( @@ -164,7 +165,6 @@ def create( is_valid_name, prepare_file_path, validate_kwargs, - validate_state, ) stage_data = validate_kwargs( @@ -184,7 +184,13 @@ def create( stage_cls, repo=self.repo, path=path, **stage_data ) if validate: - validate_state(self.repo, stage, force=force) + if not force: + from dvc.stage.utils import check_stage_exists + + check_stage_exists(self.repo, stage, stage.path) + + new_index = self.repo.index.add(stage) + new_index.check_graph() restore_meta(stage) return stage @@ -305,11 +311,13 @@ def load_glob(self, path: str, expr: str = None): @property def fs(self): + if self._fs: + return self._fs return self.repo.fs @property def graph(self) -> "DiGraph": - return self.repo.graph + return self.repo.index.graph def collect( self, @@ -351,9 +359,9 @@ def collect( glob: Use `target` as a pattern to match stages in a file. """ if not target: - return list(graph) if graph else self.repo.stages + return list(graph) if graph else list(self.repo.index) - if recursive and self.repo.fs.isdir(target): + if recursive and self.fs.isdir(target): from dvc.repo.graph import collect_inside_path path = os.path.abspath(target) @@ -392,7 +400,7 @@ def collect_granular( (see `collect()` for other arguments) """ if not target: - return [StageInfo(stage) for stage in self.repo.stages] + return [StageInfo(stage) for stage in self.repo.index] stages, file, _ = _collect_specific_target( self, target, with_deps, recursive, accept_group @@ -432,7 +440,7 @@ def collect_granular( return [StageInfo(stage) for stage in stages] - def collect_repo(self, onerror: Callable[[str, Exception], None] = None): + def _collect_repo(self, onerror: Callable[[str, Exception], None] = None): """Collects all of the stages present in the DVC repo. Args: @@ -465,7 +473,6 @@ def is_out_or_ignored(root, directory): # trailing slash needed to check if a directory is gitignored return dir_path in outs or is_ignored(f"{dir_path}{sep}") - stages = [] for root, dirs, files in self.repo.dvcignore.walk( self.fs, self.repo.root_dir ): @@ -480,7 +487,7 @@ def is_out_or_ignored(root, directory): continue raise - stages.extend(new_stages) + yield from new_stages outs.update( out.fspath for stage in new_stages @@ -488,4 +495,6 @@ def is_out_or_ignored(root, directory): if out.scheme == "local" ) dirs[:] = [d for d in dirs if not is_out_or_ignored(root, d)] - return stages + + def collect_repo(self, onerror: Callable[[str, Exception], None] = None): + return list(self._collect_repo(onerror)) diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py index 62f5299b38..462f761253 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -1,6 +1,5 @@ import os import pathlib -from contextlib import suppress from typing import TYPE_CHECKING, Union from funcy import concat, first, lsplit, rpartial, without @@ -277,7 +276,7 @@ def prepare_file_path(kwargs): ) -def _check_stage_exists( +def check_stage_exists( repo: "Repo", stage: Union["Stage", "PipelineStage"], path: str ): from dvc.dvcfile import make_dvcfile @@ -302,30 +301,6 @@ def _check_stage_exists( ) -def validate_state( - repo: "Repo", new: Union["Stage", "PipelineStage"], force: bool = False -) -> None: - """Validates that the new stage: - - * does not already exist with the same (name, path), unless force is True. - * does not affect the correctness of the repo's graph (to-be graph). - """ - stages = repo.stages[:] if force else None - if force: - # remove an existing stage (if exists) - with suppress(ValueError): - # uses name and path to determine this which is same for the - # existing (if any) and the new one - stages.remove(new) - else: - _check_stage_exists(repo, new, new.path) - # if not `force`, we don't need to replace existing stage with the new - # one, as the dumping this is not going to replace it anyway (as the - # stage with same path + name does not exist (from above check). - - repo.check_modified_graph(new_stages=[new], old_stages=stages) - - def validate_kwargs(single_stage: bool = False, fname: str = None, **kwargs): """Prepare, validate and process kwargs passed from cli""" cmd = kwargs.get("cmd") diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index e396be44b5..9b29a4571c 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -81,18 +81,15 @@ def dict_filter(d, exclude=()): """ Exclude specified keys from a nested dict """ + if not exclude or not isinstance(d, (list, dict)): + return d if isinstance(d, list): return [dict_filter(e, exclude) for e in d] - if isinstance(d, dict): - return { - k: dict_filter(v, exclude) - for k, v in d.items() - if k not in exclude - } - - return d + return { + k: dict_filter(v, exclude) for k, v in d.items() if k not in exclude + } def dict_hash(d, typ, exclude=()): diff --git a/dvc/utils/fs.py b/dvc/utils/fs.py index 4b80923e86..bfcf4cb5fa 100644 --- a/dvc/utils/fs.py +++ b/dvc/utils/fs.py @@ -147,7 +147,7 @@ def remove(path): raise -def path_isin(child: "StrPath", parent: "StrPath"): +def path_isin(child: "StrPath", parent: "StrPath") -> bool: """Check if given `child` path is inside `parent`.""" def normalize_path(path) -> str: diff --git a/tests/dir_helpers.py b/tests/dir_helpers.py index a45a62c5d0..18c163f22a 100644 --- a/tests/dir_helpers.py +++ b/tests/dir_helpers.py @@ -158,32 +158,31 @@ def scm_gen(self, struct, text="", commit=None): paths = self.gen(struct, text) return self.scm_add(paths, commit=commit) + def commit(self, output_paths, msg): + def to_gitignore(stage_path): + from dvc.scm import Git + + return os.path.join(os.path.dirname(stage_path), Git.GITIGNORE) + + gitignores = [ + to_gitignore(s) + for s in output_paths + if os.path.exists(to_gitignore(s)) + ] + return self.scm_add(output_paths + gitignores, commit=msg) + def dvc_add(self, filenames, commit=None): self._require("dvc") filenames = _coerce_filenames(filenames) stages = self.dvc.add(filenames) if commit: - stage_paths = [s.path for s in stages] - - def to_gitignore(stage_path): - from dvc.scm import Git - - return os.path.join(os.path.dirname(stage_path), Git.GITIGNORE) - - gitignores = [ - to_gitignore(s) - for s in stage_paths - if os.path.exists(to_gitignore(s)) - ] - self.scm_add(stage_paths + gitignores, commit=commit) - + self.commit([s.path for s in stages], msg=commit) return stages def scm_add(self, filenames, commit=None): self._require("scm") filenames = _coerce_filenames(filenames) - self.scm.add(filenames) if commit: self.scm.commit(commit) diff --git a/tests/func/conftest.py b/tests/func/conftest.py index 41851900f8..7d52b5622f 100644 --- a/tests/func/conftest.py +++ b/tests/func/conftest.py @@ -5,11 +5,23 @@ @pytest.fixture def run_copy_metrics(tmp_dir, run_copy): - def run(file1, file2, commit=None, tag=None, single_stage=True, **kwargs): + def run( + file1, + file2, + commit=None, + tag=None, + single_stage=True, + name=None, + **kwargs, + ): + if name: + single_stage = False + stage = tmp_dir.dvc.run( cmd=f"python copy.py {file1} {file2}", deps=[file1], single_stage=single_stage, + name=name, **kwargs, ) diff --git a/tests/func/test_add.py b/tests/func/test_add.py index f8d15d7ced..8e2040cddf 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -1206,7 +1206,7 @@ def test_add_on_not_existing_file_should_not_remove_stage_file(tmp_dir, dvc): @pytest.mark.parametrize( "target", [ - "dvc.repo.Repo.check_modified_graph", + "dvc.repo.index.Index.check_graph", "dvc.stage.Stage.save", "dvc.stage.Stage.commit", ], diff --git a/tests/func/test_dvcfile.py b/tests/func/test_dvcfile.py index 9bbbaee76e..6710c04cb6 100644 --- a/tests/func/test_dvcfile.py +++ b/tests/func/test_dvcfile.py @@ -164,7 +164,7 @@ def test_stage_collection(tmp_dir, dvc): always_changed=True, single_stage=True, ) - assert set(dvc.stages) == {stage1, stage3, stage2} + assert set(dvc.index.stages) == {stage1, stage3, stage2} def test_remove_stage(tmp_dir, dvc, run_copy): diff --git a/tests/func/test_repo_index.py b/tests/func/test_repo_index.py new file mode 100644 index 0000000000..0224d6eb0c --- /dev/null +++ b/tests/func/test_repo_index.py @@ -0,0 +1,247 @@ +import os +from itertools import chain + +import pytest +from pygtrie import Trie + +from dvc.repo.index import Index +from dvc.stage import PipelineStage, Stage +from dvc.utils import relpath +from tests.func.plots.utils import _write_json + + +def test_index(tmp_dir, scm, dvc, run_copy): + (stage1,) = tmp_dir.dvc_gen("foo", "foo") + stage2 = run_copy("foo", "bar", name="copy-foo-bar") + tmp_dir.commit([s.outs[0].fspath for s in (stage1, stage2)], msg="add") + + index = Index(dvc) + assert index.fs == dvc.fs + + assert len(index) == len(index.stages) == 2 + assert set(index.stages) == set(index) == {stage1, stage2} + assert stage1 in index + assert stage2 in index + + assert index.outs_graph + assert index.graph + assert index.build_graph() + assert isinstance(index.outs_trie, Trie) + assert index.identifier + index.check_graph() + + +def test_repr(tmp_dir, scm, dvc): + tmp_dir.dvc_gen("foo", "foo", commit="add foo") + + brancher = dvc.brancher([scm.get_rev()]) + rev = next(brancher) + assert rev == "workspace" + assert repr(Index(dvc)) == f"Index({dvc}, fs@{rev})" + + rev = next(brancher) + assert rev == scm.get_rev() + assert repr(Index(dvc)) == f"Index({dvc}, fs@{rev[:7]})" + + +def test_filter_index(tmp_dir, dvc, run_copy): + tmp_dir.dvc_gen("foo", "foo") + stage2 = run_copy("foo", "bar", name="copy-foo-bar") + + def filter_pipeline(stage): + return bool(stage.cmd) + + filtered_index = Index(dvc).filter(filter_pipeline) + assert list(filtered_index) == [stage2] + + +def test_slice_index(tmp_dir, dvc): + tmp_dir.gen({"dir1": {"foo": "foo"}, "dir2": {"bar": "bar"}}) + with (tmp_dir / "dir1").chdir(): + (stage1,) = dvc.add("foo") + with (tmp_dir / "dir2").chdir(): + (stage2,) = dvc.add("bar") + + index = Index(dvc) + + sliced = index.slice("dir1") + assert set(sliced) == {stage1} + assert sliced.stages is not index.stages # sanity check + + sliced = index.slice(tmp_dir / "dir1") + assert set(sliced) == {stage1} + + sliced = index.slice("dir2") + assert set(sliced) == {stage2} + + with (tmp_dir / "dir1").chdir(): + sliced = index.slice(relpath(tmp_dir / "dir2")) + assert set(sliced) == {stage2} + + +def outputs_equal(actual, expected): + actual, expected = list(actual), list(expected) + + def sort_fn(output): + return output.fspath + + assert len(actual) == len(expected) + pairs = zip(sorted(actual, key=sort_fn), sorted(expected, key=sort_fn)) + assert all(actual.fspath == expected.fspath for actual, expected in pairs) + return True + + +def test_deps_outs_getters(tmp_dir, dvc, run_copy_metrics): + (foo_stage,) = tmp_dir.dvc_gen({"foo": "foo"}) + tmp_dir.gen({"params.yaml": "param: 100\n"}) + tmp_dir.gen({"m_temp.yaml": str(5)}) + + run_stage1 = run_copy_metrics( + "m_temp.yaml", + "m.yaml", + metrics=["m.yaml"], + params=["param"], + name="copy-metrics", + ) + _write_json(tmp_dir, [{"a": 1, "b": 2}, {"a": 2, "b": 3}], "metric_t.json") + run_stage2 = run_copy_metrics( + "metric_t.json", + "metric.json", + plots_no_cache=["metric.json"], + name="copy-metrics2", + ) + + index = Index(dvc) + + stages = [foo_stage, run_stage1, run_stage2] + (metrics,) = run_stage1.outs + _, params = run_stage1.deps + (plots,) = run_stage2.outs + + expected_outs = chain.from_iterable([stage.outs for stage in stages]) + expected_deps = chain.from_iterable([stage.deps for stage in stages]) + + assert outputs_equal(index.outs, expected_outs) + assert outputs_equal(index.deps, expected_deps) + assert outputs_equal(index.decorated_outputs, [metrics, plots]) + assert outputs_equal(index.metrics, [metrics]) + assert outputs_equal(index.plots, [plots]) + assert outputs_equal(index.params, [params]) + + +def test_add_update(dvc): + """Test that add/update overwrites existing stages with the new ones. + + The old stages and the new ones might have same hash, so we are + making sure that the old stages were removed and replaced by new ones + using `id`/`is` checks. + """ + index = Index(dvc) + new_stage = Stage(dvc, path="path1") + new_index = index.add(new_stage) + + assert not index.stages + assert new_index.stages == [new_stage] + + dup_stage1 = Stage(dvc, path="path1") + dup_stage2 = Stage(dvc, path="path2") + dup_index = index.update([dup_stage1, dup_stage2]) + assert not index.stages + assert len(new_index) == 1 + assert new_index.stages[0] is new_stage + assert set(map(id, dup_index.stages)) == {id(dup_stage1), id(dup_stage2)} + + +def assert_index_equal(first, second, strict=True, ordered=True): + assert len(first) == len(second), "Index have different no. of stages" + assert set(first) == set(second), "Index does not have same stages" + if ordered: + assert list(first) == list( + second + ), "Index does not have same sequence of stages" + if strict: + assert set(map(id, first)) == set( + map(id, second) + ), "Index is not strictly equal" + + +def test_discard_remove(dvc): + stage = Stage(dvc, path="path1") + index = Index(dvc, stages=[stage]) + + assert list(index.discard(Stage(dvc, "path2"))) == list(index) + new_index = index.discard(stage) + assert len(new_index) == 0 + + with pytest.raises(ValueError): + index.remove(Stage(dvc, "path2")) + assert index.stages == [stage] + assert list(index.remove(stage)) == [] + + +def test_difference(dvc): + stages = [Stage(dvc, path=f"path{i}") for i in range(10)] + index = Index(dvc, stages=stages) + + new_index = index.difference([*stages[:5], Stage(dvc, path="path100")]) + assert index.stages == stages + assert set(new_index) == set(stages[5:]) + + +def test_dumpd(dvc): + stages = [ + PipelineStage(dvc, "dvc.yaml", name="stage1"), + Stage(dvc, "path"), + ] + index = Index(dvc, stages=stages) + assert index.dumpd() == {"dvc.yaml:stage1": {}, "path": {}} + assert index.identifier == "d43da84e9001540c26abf2bf4541c275" + + +def test_unique_identifier(tmp_dir, dvc, scm, run_copy): + dvc.config["core"]["autostage"] = True + tmp_dir.dvc_gen("foo", "foo") + run_copy("foo", "bar", name="copy-foo-bar") + + revs = [] + n_commits = 5 + for i in range(n_commits): + # create a few empty commits + scm.commit(f"commit {i}") + revs.append(scm.get_rev()) + assert len(set(revs)) == n_commits # the commit revs should be unique + + ids = [] + for _ in dvc.brancher(revs=revs): + index = Index(dvc) + assert index.stages + ids.append(index.identifier) + + # we get "workspace" as well from the brancher by default + assert len(revs) + 1 == len(ids) + possible_ids = { + True: "2ba7c7c5b395d4211348d6274b869fc7", + False: "8406970ad2fcafaa84d9310330a67576", + } + assert set(ids) == {possible_ids[os.name == "posix"]} + + +def test_skip_graph_checks(dvc, mocker): + # See https://github.com/iterative/dvc/issues/2671 for more info + mock_build_graph = mocker.spy(Index, "build_graph") + + # sanity check + Index(dvc).check_graph() + assert mock_build_graph.called + mock_build_graph.reset_mock() + + # check that our hack can be enabled + dvc._skip_graph_checks = True + Index(dvc).check_graph() + assert not mock_build_graph.called + mock_build_graph.reset_mock() + + # check that our hack can be disabled + dvc._skip_graph_checks = False + Index(dvc).check_graph() + assert mock_build_graph.called diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 948505a921..46b4801231 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -169,10 +169,12 @@ def test_nested(self): # NOTE: os.walk() walks in a sorted order and we need dir2 subdirs to # be processed before dir1 to load error.dvc first. - self.dvc.stages = [ - nested_stage, - Dvcfile(self.dvc, error_stage_path).stage, - ] + self.dvc.index = self.dvc.index.update( + [ + nested_stage, + Dvcfile(self.dvc, error_stage_path).stage, + ] + ) with patch.object(self.dvc, "_reset"): # to prevent `stages` resetting with self.assertRaises(StagePathAsOutputError): diff --git a/tests/func/test_stage_load.py b/tests/func/test_stage_load.py index de3f6b89d0..5264e15bf1 100644 --- a/tests/func/test_stage_load.py +++ b/tests/func/test_stage_load.py @@ -86,7 +86,9 @@ def test_collect_with_not_existing_output_or_stage_name( def test_stages(tmp_dir, dvc): def collect_stages(): - return {stage.relpath for stage in Repo(os.fspath(tmp_dir)).stages} + return { + stage.relpath for stage in Repo(os.fspath(tmp_dir)).index.stages + } tmp_dir.dvc_gen({"file": "a", "dir/file": "b", "dir/subdir/file": "c"}) @@ -149,7 +151,7 @@ def test_collect_generated(tmp_dir, dvc): } dump_yaml("dvc.yaml", d) - all_stages = set(dvc.stages) + all_stages = set(dvc.index.stages) assert len(all_stages) == 5 assert set(dvc.stage.collect()) == all_stages @@ -412,7 +414,7 @@ def test_collect_optimization(tmp_dir, dvc, mocker): # Forget cached stages and graph and error out on collection dvc._reset() mocker.patch( - "dvc.repo.Repo.stages", + "dvc.repo.index.Index.stages", property(raiser(Exception("Should not collect"))), ) @@ -427,7 +429,7 @@ def test_collect_optimization_on_stage_name(tmp_dir, dvc, mocker, run_copy): # Forget cached stages and graph and error out on collection dvc._reset() mocker.patch( - "dvc.repo.Repo.stages", + "dvc.repo.index.Index.stages", property(raiser(Exception("Should not collect"))), ) @@ -444,7 +446,7 @@ def test_collect_repo_callback(tmp_dir, dvc, mocker): dump_yaml(tmp_dir / PIPELINE_FILE, {"stages": {"cmd": "echo hello world"}}) dvc._reset() - assert dvc.stages == [stage] + assert dvc.index.stages == [stage] mock.assert_called_once() file_path, exc = mock.call_args[0] diff --git a/tests/unit/command/test_dag.py b/tests/unit/command/test_dag.py index f351585fb6..cb2e508ba3 100644 --- a/tests/unit/command/test_dag.py +++ b/tests/unit/command/test_dag.py @@ -17,7 +17,7 @@ def test_dag(tmp_dir, dvc, mocker, fmt): cmd = cli_args.func(cli_args) - mocker.patch("dvc.command.dag._build", return_value=dvc.graph) + mocker.patch("dvc.command.dag._build", return_value=dvc.index.graph) assert cmd.run() == 0 @@ -46,7 +46,7 @@ def repo(tmp_dir, dvc): def test_build(repo): - assert nx.is_isomorphic(_build(repo), repo.graph) + assert nx.is_isomorphic(_build(repo), repo.index.graph) def test_build_target(repo): @@ -69,7 +69,7 @@ def test_build_granular_target_with_outs(repo): def test_build_full(repo): G = _build(repo, target="3", full=True) - assert nx.is_isomorphic(G, repo.graph) + assert nx.is_isomorphic(G, repo.index.graph) # NOTE: granular or not, full outs DAG should be the same @@ -94,7 +94,7 @@ def test_build_full_outs(repo, granular): def test_show_ascii(repo): assert [ - line.rstrip() for line in _show_ascii(repo.graph).splitlines() + line.rstrip() for line in _show_ascii(repo.index.graph).splitlines() ] == [ " +----------------+ +----------------+", # noqa: E501 " | stage: 'a.dvc' | | stage: 'b.dvc' |", # noqa: E501 @@ -115,7 +115,7 @@ def test_show_ascii(repo): def test_show_dot(repo): - assert _show_dot(repo.graph) == ( + assert _show_dot(repo.index.graph) == ( "strict digraph {\n" "stage;\n" "stage;\n" diff --git a/tests/unit/command/test_status.py b/tests/unit/command/test_status.py index bf92bac1d2..23dd9502d3 100644 --- a/tests/unit/command/test_status.py +++ b/tests/unit/command/test_status.py @@ -81,7 +81,7 @@ def test_status_empty(dvc, mocker, capsys): cmd = cli_args.func(cli_args) - spy = mocker.spy(cmd.repo.stage, "collect_repo") + spy = mocker.spy(cmd.repo.stage, "_collect_repo") assert cmd.run() == 0 @@ -108,7 +108,7 @@ def test_status_up_to_date(dvc, mocker, capsys, cloud_opts, expected_message): mocker.patch.dict(cmd.repo.config, {"core": {"remote": "default"}}) mocker.patch.object(cmd.repo, "status", autospec=True, return_value={}) mocker.patch.object( - cmd.repo.stage, "collect_repo", return_value=[object()], autospec=True + cmd.repo.stage, "_collect_repo", return_value=[object()], autospec=True ) assert cmd.run() == 0 diff --git a/tests/unit/repo/test_repo.py b/tests/unit/repo/test_repo.py index 56e07c8ba3..96040b379e 100644 --- a/tests/unit/repo/test_repo.py +++ b/tests/unit/repo/test_repo.py @@ -77,7 +77,7 @@ def test_locked(mocker): def test_skip_graph_checks(tmp_dir, dvc, mocker, run_copy): # See https://github.com/iterative/dvc/issues/2671 for more info - mock_build_graph = mocker.patch("dvc.repo.build_graph") + mock_build_graph = mocker.patch("dvc.repo.index.Index.build_graph") # sanity check tmp_dir.gen("foo", "foo text")