From 67900a73b26dc344b0d5ce02030d7095786ecfcb Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 17 Mar 2023 15:23:11 +0200 Subject: [PATCH] dep: repo: use dvcfs Makes dvc get/import/etc use dvcfs, which already supports cloud versioning, circular imports and other stuff. Also makes dvc import behave more like dvc import-url, so that we can use the same existing logic for fetching those using index instead of objects. Fixes #8789 Related iterative/studio#4782 --- dvc/dependency/repo.py | 187 ++++------------------------------ dvc/exceptions.py | 8 -- dvc/output.py | 15 +-- dvc/repo/imports.py | 17 +--- dvc/repo/index.py | 3 +- dvc/repo/init.py | 5 +- pyproject.toml | 2 +- tests/func/test_data_cloud.py | 5 +- tests/func/test_import.py | 49 +-------- 9 files changed, 39 insertions(+), 252 deletions(-) diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index d187469bd3..91959caa9f 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -1,25 +1,14 @@ -import errno -import os -from collections import defaultdict -from copy import copy -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Union from voluptuous import Required -from dvc.prompt import confirm +from dvc.utils import as_posix from .base import Dependency if TYPE_CHECKING: - from typing import ContextManager - - from dvc.output import Output - from dvc.repo import Repo + from dvc.fs import DVCFileSystem from dvc.stage import Stage - from dvc_data.hashfile.hash_info import HashInfo - from dvc_data.hashfile.meta import Meta - from dvc_data.hashfile.obj import HashFile - from dvc_objects.db import ObjectDB class RepoDependency(Dependency): @@ -37,18 +26,11 @@ class RepoDependency(Dependency): } def __init__(self, def_repo: Dict[str, str], stage: "Stage", *args, **kwargs): - from dvc.fs import DVCFileSystem - self.def_repo = def_repo - self._objs: Dict[str, "HashFile"] = {} - self._meta: Dict[str, "Meta"] = {} super().__init__(stage, *args, **kwargs) - self.fs = DVCFileSystem( - url=self.def_repo[self.PARAM_URL], - rev=self._get_rev(), - ) - self.fs_path = self.def_path + self.fs = self._make_fs() + self.fs_path = as_posix(self.def_path) def _parse_path(self, fs, fs_path): # noqa: ARG002 return None @@ -61,8 +43,8 @@ def __str__(self): return f"{self.def_path} ({self.def_repo[self.PARAM_URL]})" def workspace_status(self): - current = self.get_obj(locked=True).hash_info - updated = self.get_obj(locked=False).hash_info + current = self._make_fs(locked=True).repo.get_rev() + updated = self._make_fs(locked=False).repo.get_rev() if current != updated: return {str(self): "update available"} @@ -73,33 +55,18 @@ def status(self): return self.workspace_status() def save(self): - pass + rev = self.fs.repo.get_rev() + if self.def_repo.get(self.PARAM_REV_LOCK) is None: + self.def_repo[self.PARAM_REV_LOCK] = rev def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]: return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo} - def download(self, to: "Output", jobs: Optional[int] = None): - from dvc_data.hashfile.checkout import checkout - - for odb, objs in self.get_used_objs().items(): - self.repo.cloud.pull(objs, jobs=jobs, odb=odb) - - obj = self.get_obj() - checkout( - to.fs_path, - to.fs, - obj, - self.repo.cache.local, - ignore=None, - state=self.repo.state, - prompt=confirm, - ) - def update(self, rev: Optional[str] = None): if rev: self.def_repo[self.PARAM_REV] = rev - with self._make_repo(locked=False) as repo: - 
self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev() + self.fs = self._make_fs(rev=rev, locked=False) + self.def_repo[self.PARAM_REV_LOCK] = self.fs.repo.get_rev() def changed_checksum(self) -> bool: # From current repo point of view what describes RepoDependency is its @@ -107,131 +74,15 @@ def changed_checksum(self) -> bool: # immutable, hence its impossible for checksum to change. return False - def get_used_objs(self, **kwargs) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]: - used, _, _ = self._get_used_and_obj(**kwargs) - return used - - def _get_used_and_obj( - self, obj_only: bool = False, **kwargs - ) -> Tuple[Dict[Optional["ObjectDB"], Set["HashInfo"]], "Meta", "HashFile"]: - from dvc.config import NoRemoteError - from dvc.exceptions import NoOutputOrStageError - from dvc.utils import as_posix - from dvc_data.hashfile.build import build - from dvc_data.hashfile.tree import Tree, TreeError - - local_odb = self.repo.cache.local - locked = kwargs.pop("locked", True) - with self._make_repo(locked=locked, cache_dir=local_odb.path) as repo: - used_obj_ids = defaultdict(set) - rev = repo.get_rev() - if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None: - self.def_repo[self.PARAM_REV_LOCK] = rev - - if not obj_only: - try: - for odb, obj_ids in repo.used_objs( - [os.path.join(repo.root_dir, self.def_path)], - force=True, - jobs=kwargs.get("jobs"), - recursive=True, - ).items(): - if odb is None: - odb = repo.cloud.get_remote_odb() - odb.read_only = True - self._check_circular_import(odb, obj_ids) - used_obj_ids[odb].update(obj_ids) - except (NoRemoteError, NoOutputOrStageError): - pass - - try: - object_store, meta, obj = build( - local_odb, - as_posix(self.def_path), - repo.dvcfs, - local_odb.fs.PARAM_CHECKSUM, - ) - except (FileNotFoundError, TreeError) as exc: - raise FileNotFoundError( - errno.ENOENT, - os.strerror(errno.ENOENT) + f" in {self.def_repo[self.PARAM_URL]}", - self.def_path, - ) from exc - object_store = copy(object_store) - object_store.read_only = True - - self._objs[rev] = obj - self._meta[rev] = meta - - used_obj_ids[object_store].add(obj.hash_info) - if isinstance(obj, Tree): - used_obj_ids[object_store].update(oid for _, _, oid in obj) - return used_obj_ids, meta, obj - - def _check_circular_import(self, odb: "ObjectDB", obj_ids: Set["HashInfo"]) -> None: - from dvc.exceptions import CircularImportError - from dvc.fs.dvc import DVCFileSystem - from dvc_data.hashfile.db.reference import ReferenceHashFileDB - from dvc_data.hashfile.tree import Tree - - if not isinstance(odb, ReferenceHashFileDB): - return - - def iter_objs(): - for hash_info in obj_ids: - if hash_info.isdir: - tree = Tree.load(odb, hash_info) - yield from (odb.get(hi.value) for _, _, hi in tree) - else: - assert hash_info.value - yield odb.get(hash_info.value) - - checked_urls = set() - for obj in iter_objs(): - if not isinstance(obj.fs, DVCFileSystem): - continue - if obj.fs.repo_url in checked_urls or obj.fs.repo.root_dir in checked_urls: - continue - self_url = self.repo.url or self.repo.root_dir - if ( - obj.fs.repo_url is not None - and obj.fs.repo_url == self_url - or obj.fs.repo.root_dir == self.repo.root_dir - ): - raise CircularImportError(self, obj.fs.repo_url, self_url) - checked_urls.update([obj.fs.repo_url, obj.fs.repo.root_dir]) - - def get_obj(self, filter_info=None, **kwargs): - locked = kwargs.get("locked", True) - rev = self._get_rev(locked=locked) - if rev in self._objs: - return self._objs[rev] - _, _, obj = self._get_used_and_obj( - obj_only=True, filter_info=filter_info, 
**kwargs - ) - return obj - - def get_meta(self, filter_info=None, **kwargs): - locked = kwargs.get("locked", True) - rev = self._get_rev(locked=locked) - if rev in self._meta: - return self._meta[rev] - _, meta, _ = self._get_used_and_obj( - obj_only=True, filter_info=filter_info, **kwargs - ) - return meta - - def _make_repo(self, locked: bool = True, **kwargs) -> "ContextManager[Repo]": - from dvc.external_repo import external_repo + def _make_fs( + self, rev: Optional[str] = None, locked: bool = True + ) -> "DVCFileSystem": + from dvc.fs import DVCFileSystem - d = self.def_repo - rev = self._get_rev(locked=locked) - return external_repo( - d[self.PARAM_URL], - rev=rev, + return DVCFileSystem( + url=self.def_repo[self.PARAM_URL], + rev=rev or self._get_rev(locked=locked), subrepos=True, - uninitialized=True, - **kwargs, ) def _get_rev(self, locked: bool = True): diff --git a/dvc/exceptions.py b/dvc/exceptions.py index f3153e3186..84cc99f283 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -338,14 +338,6 @@ def __init__(self, fs_paths): self.fs_paths = fs_paths -class CircularImportError(DvcException): - def __init__(self, dep, a, b): - super().__init__( - f"'{dep}' contains invalid circular import. " - f"DVC repo '{a}' already imports from '{b}'." - ) - - class PrettyDvcException(DvcException): def __pretty_exc__(self, **kwargs): """Print prettier exception message.""" diff --git a/dvc/output.py b/dvc/output.py index 8267382e25..cadf48f995 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -1066,7 +1066,7 @@ def _collect_used_dir_cache( return obj.filter(prefix) return obj - def get_used_objs( # noqa: C901, PLR0911 + def get_used_objs( # noqa: C901 self, **kwargs ) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]: """Return filtered set of used object IDs for this out.""" @@ -1076,9 +1076,7 @@ def get_used_objs( # noqa: C901, PLR0911 push: bool = kwargs.pop("push", False) if self.stage.is_repo_import: - if push: - return {} - return self.get_used_external(**kwargs) + return {} if push and not self.can_push: return {} @@ -1130,15 +1128,6 @@ def _named_obj_ids(self, obj): oids.add(oid) return oids - def get_used_external( - self, **kwargs - ) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]: - if not self.use_cache or not self.stage.is_repo_import: - return {} - - (dep,) = self.stage.deps - return dep.get_used_objs(**kwargs) - def _validate_output_path(self, path, stage=None): from dvc.dvcfile import is_valid_filename diff --git a/dvc/repo/imports.py b/dvc/repo/imports.py index 0ec11e065e..d6f27a9051 100644 --- a/dvc/repo/imports.py +++ b/dvc/repo/imports.py @@ -28,11 +28,7 @@ def unfetched_view( changed_deps: List["Dependency"] = [] def need_fetch(stage: "Stage") -> bool: - if ( - not stage.is_import - or stage.is_repo_import - or (stage.is_partial_import and not unpartial) - ): + if not stage.is_import or (stage.is_partial_import and not unpartial): return False out = stage.outs[0] @@ -65,6 +61,7 @@ def unpartial_imports(index: Union["Index", "IndexView"]) -> int: Total number of files which were unpartialed. 
""" from dvc_data.hashfile.hash_info import HashInfo + from dvc_data.hashfile.meta import Meta updated = 0 for out in index.outs: @@ -73,14 +70,8 @@ def unpartial_imports(index: Union["Index", "IndexView"]) -> int: workspace, key = out.index_key entry = index.data[workspace][key] if out.stage.is_partial_import: - if out.stage.is_repo_import: - dep = out.stage.deps[0] - out.hash_info = dep.get_obj().hash_info - out.meta = dep.get_meta() - else: - assert isinstance(entry.hash_info, HashInfo) - out.hash_info = entry.hash_info - out.meta = entry.meta + out.hash_info = entry.hash_info or HashInfo() + out.meta = entry.meta or Meta() out.stage.md5 = out.stage.compute_md5() out.stage.dump() updated += out.meta.nfiles if out.meta.nfiles is not None else 1 diff --git a/dvc/repo/index.py b/dvc/repo/index.py index 156c0f2ebc..e2d2249d9c 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -530,7 +530,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]: workspace, key = out.index_key if filter_info and out.fs.path.isin(filter_info, out.fs_path): key = key + out.fs.path.relparts(filter_info, out.fs_path) - if out.meta.isdir or out.stage.is_import and out.stage.deps[0].meta.isdir: + entry = self._index.data[workspace][key] + if entry and entry.meta and entry.meta.isdir: prefixes[workspace].recursive.add(key) prefixes[workspace].explicit.update(key[:i] for i in range(len(key), 0, -1)) return prefixes diff --git a/dvc/repo/init.py b/dvc/repo/init.py index 5a9ab093da..4a68a2bb2c 100644 --- a/dvc/repo/init.py +++ b/dvc/repo/init.py @@ -77,7 +77,10 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False): # noqa: if os.path.isdir(proj.site_cache_dir): proj.close() - remove(proj.site_cache_dir) + try: + remove(proj.site_cache_dir) + except OSError: + logger.debug("failed to remove %s", dvc_dir, exc_info=True) proj = Repo(root_dir) with proj.scm_context(autostage=True) as context: diff --git a/pyproject.toml b/pyproject.toml index 0037e8b5a0..e0ce0cdae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "configobj>=5.0.6", "distro>=1.3", "dpath<3,>=2.1.0", - "dvc-data>=0.46.0,<0.47", + "dvc-data>=0.47.1,<0.48", "dvc-http", "dvc-render>=0.3.1,<0.4.0", "dvc-studio-client>=0.6.1,<1", diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 286e01ce1e..f4d86b9dd7 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -226,12 +226,13 @@ def test_pull_git_imports(tmp_dir, dvc, scm, erepo): assert dvc.pull()["fetched"] == 0 - for item in ["foo", "new_dir", dvc.cache.local.path]: + for item in ["foo", "new_dir"]: remove(item) + dvc.cache.local.clear() os.makedirs(dvc.cache.local.path, exist_ok=True) clean_repos() - assert dvc.pull(force=True)["fetched"] == 3 + assert dvc.pull(force=True)["fetched"] == 2 assert (tmp_dir / "foo").exists() assert (tmp_dir / "foo").read_text() == "foo" diff --git a/tests/func/test_import.py b/tests/func/test_import.py index b087918f1e..dbf252d948 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -8,7 +8,6 @@ from dvc.cachemgr import CacheManager from dvc.config import NoRemoteError from dvc.dvcfile import load_file -from dvc.exceptions import DownloadError from dvc.fs import system from dvc.scm import Git from dvc.stage.exceptions import StagePathNotFoundError @@ -240,7 +239,9 @@ def test_pull_import_no_download(tmp_dir, scm, dvc, erepo_dir): dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", no_download=True) dvc.pull(["foo_imported.dvc"]) - assert 
os.path.exists("foo_imported")
+    assert (tmp_dir / "foo_imported").exists()
+    assert (tmp_dir / "foo_imported" / "bar").read_bytes() == b"bar"
+    assert (tmp_dir / "foo_imported" / "baz").read_bytes() == b"baz contents"
 
     stage = load_file(dvc, "foo_imported.dvc").stage
 
@@ -338,34 +339,7 @@ def test_push_wildcard_from_bare_git_repo(
     with dvc_repo.chdir():
         dvc_repo.dvc.imp(os.fspath(tmp_dir), "dirextra")
 
-    with pytest.raises(FileNotFoundError):
-        dvc_repo.dvc.imp(os.fspath(tmp_dir), "dir123")
-
-
-def test_download_error_pulling_imported_stage(mocker, tmp_dir, dvc, erepo_dir):
-    with erepo_dir.chdir():
-        erepo_dir.dvc_gen("foo", "foo content", commit="create foo")
-    dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
-
-    dst_stage = load_file(dvc, "foo_imported.dvc").stage
-    dst_cache = dst_stage.outs[0].cache_path
-
-    remove("foo_imported")
-    remove(dst_cache)
-
-    def unreliable_download(_from_fs, from_info, _to_fs, to_info, **kwargs):
-        on_error = kwargs["on_error"]
-        assert on_error
-        if isinstance(from_info, str):
-            from_info = [from_info]
-        if isinstance(to_info, str):
-            to_info = [to_info]
-        for from_i, to_i in zip(from_info, to_info):
-            on_error(from_i, to_i, Exception())
-
-    mocker.patch("dvc_objects.fs.generic.transfer", unreliable_download)
-    with pytest.raises(DownloadError):
-        dvc.pull(["foo_imported.dvc"])
+    dvc_repo.dvc.imp(os.fspath(tmp_dir), "dir123")
 
 
 @pytest.mark.parametrize("dname", [".", "dir", "dir/subdir"])
@@ -634,21 +608,6 @@ def test_chained_import(tmp_dir, dvc, make_tmp_dir, erepo_dir, local_cloud):
     assert (dst / "bar").read_text() == "bar"
 
 
-def test_circular_import(tmp_dir, dvc, scm, erepo_dir):
-    from dvc.exceptions import CircularImportError
-
-    with erepo_dir.chdir():
-        erepo_dir.dvc_gen({"dir": {"foo": "foo", "bar": "bar"}}, commit="init")
-
-    dvc.imp(os.fspath(erepo_dir), "dir", "dir_imported")
-    scm.add("dir_imported.dvc")
-    scm.commit("import")
-
-    with erepo_dir.chdir():
-        with pytest.raises(CircularImportError):
-            erepo_dir.dvc.imp(os.fspath(tmp_dir), "dir_imported", "circular_import")
-
-
 @pytest.mark.parametrize("paths", ([], ["dir"]))
 def test_parameterized_repo(tmp_dir, dvc, scm, erepo_dir, paths):
     path = erepo_dir.joinpath(*paths)
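
For reference, a minimal sketch of the DVCFileSystem usage this patch builds
on. Only the constructor arguments and attributes shown in the diff above are
assumed; the repository URL and file path are illustrative:

    from dvc.fs import DVCFileSystem

    # Same construction as RepoDependency._make_fs(): resolve a repo at a
    # given rev, with subrepos enabled.
    fs = DVCFileSystem(
        url="https://github.com/iterative/example-get-started",
        rev="main",
        subrepos=True,
    )

    print(fs.repo.get_rev())  # resolved commit SHA, as used by save()/update()
    fs.get_file("data/data.xml", "data.xml")  # download one DVC-tracked file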
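
With the same caveat, the new workspace_status() reduces the update check to
comparing the resolved Git revisions of the pinned (rev_lock) and floating
(rev) views of the source repo, instead of building and hashing the imported
data:

    from dvc.fs import DVCFileSystem

    # Sketch of the rev-comparison check; url, rev and rev_lock correspond to
    # the fields RepoDependency keeps in def_repo.
    def update_available(url: str, rev: str, rev_lock: str) -> bool:
        pinned = DVCFileSystem(url=url, rev=rev_lock).repo.get_rev()
        latest = DVCFileSystem(url=url, rev=rev).repo.get_rev()
        return pinned != latest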