dep: repo: use dvcfs
Makes dvc get/import/etc. use dvcfs, which already supports cloud versioning,
circular imports, and other features.
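
As a rough sketch (not part of the commit), this is the kind of access dvcfs
provides; the constructor arguments mirror the ones in this diff, while the
repo URL, file path, and the fsspec-style open/get_file calls are illustrative
assumptions:

    # Minimal sketch, assuming a DVC-tracked file at data/data.xml in a
    # hypothetical source repo. DVCFileSystem(url=..., rev=...) matches
    # the constructor used in the diff below.
    from dvc.fs import DVCFileSystem

    fs = DVCFileSystem(
        url="https://github.com/iterative/example-get-started",  # example repo
        rev="main",
        subrepos=True,
        uninitialized=True,
    )

    # Stream a tracked file from cache/remote (cloud-versioned remotes included),
    with fs.open("data/data.xml") as fobj:
        fobj.read(64)

    # or materialize it locally, roughly what `dvc get` now does under the hood.
    fs.get_file("data/data.xml", "data.xml")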

Also makes dvc import behave more like dvc import-url, so that we can reuse
the same existing fetching logic for both, driven by the index instead of by
objects.
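
Concretely, once save() pins rev_lock (see dvc/dependency/repo.py below), a
repo-import dependency serializes via dumpd() to roughly the shape sketched
here; the path, URL, and revision values are made up:

    # Hypothetical dumpd() result for a repo import; the "path"/"repo" keys
    # come from PARAM_PATH/PARAM_REPO in the diff, the values are invented.
    dep_entry = {
        "path": "data/data.xml",
        "repo": {
            "url": "https://github.com/iterative/example-get-started",
            "rev_lock": "96fdd8f12c14fa58a1b7354f15c7adb50e4e8f52",
        },
    }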

Fixes #8789
Related iterative/studio#4782
efiop committed Apr 2, 2023
1 parent 24e10e0 commit 67900a7
Showing 9 changed files with 39 additions and 252 deletions.
187 changes: 19 additions & 168 deletions dvc/dependency/repo.py
@@ -1,25 +1,14 @@
-import errno
 import os
-from collections import defaultdict
-from copy import copy
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
 from voluptuous import Required
 
-from dvc.prompt import confirm
 from dvc.utils import as_posix
 
 from .base import Dependency
 
 if TYPE_CHECKING:
-    from typing import ContextManager
-
-    from dvc.output import Output
-    from dvc.repo import Repo
+    from dvc.fs import DVCFileSystem
     from dvc.stage import Stage
-    from dvc_data.hashfile.hash_info import HashInfo
-    from dvc_data.hashfile.meta import Meta
-    from dvc_data.hashfile.obj import HashFile
-    from dvc_objects.db import ObjectDB
 
 
@@ -37,18 +26,11 @@ class RepoDependency(Dependency):
     }
 
     def __init__(self, def_repo: Dict[str, str], stage: "Stage", *args, **kwargs):
-        from dvc.fs import DVCFileSystem
-
         self.def_repo = def_repo
-        self._objs: Dict[str, "HashFile"] = {}
-        self._meta: Dict[str, "Meta"] = {}
         super().__init__(stage, *args, **kwargs)
 
-        self.fs = DVCFileSystem(
-            url=self.def_repo[self.PARAM_URL],
-            rev=self._get_rev(),
-        )
-        self.fs_path = self.def_path
+        self.fs = self._make_fs()
+        self.fs_path = as_posix(self.def_path)
 
     def _parse_path(self, fs, fs_path):  # noqa: ARG002
         return None
@@ -61,8 +43,8 @@ def __str__(self):
         return f"{self.def_path} ({self.def_repo[self.PARAM_URL]})"
 
     def workspace_status(self):
-        current = self.get_obj(locked=True).hash_info
-        updated = self.get_obj(locked=False).hash_info
+        current = self._make_fs(locked=True).repo.get_rev()
+        updated = self._make_fs(locked=False).repo.get_rev()
 
         if current != updated:
            return {str(self): "update available"}
@@ -73,165 +55,34 @@ def status(self):
         return self.workspace_status()
 
     def save(self):
-        pass
+        rev = self.fs.repo.get_rev()
+        if self.def_repo.get(self.PARAM_REV_LOCK) is None:
+            self.def_repo[self.PARAM_REV_LOCK] = rev
 
     def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
         return {self.PARAM_PATH: self.def_path, self.PARAM_REPO: self.def_repo}
 
-    def download(self, to: "Output", jobs: Optional[int] = None):
-        from dvc_data.hashfile.checkout import checkout
-
-        for odb, objs in self.get_used_objs().items():
-            self.repo.cloud.pull(objs, jobs=jobs, odb=odb)
-
-        obj = self.get_obj()
-        checkout(
-            to.fs_path,
-            to.fs,
-            obj,
-            self.repo.cache.local,
-            ignore=None,
-            state=self.repo.state,
-            prompt=confirm,
-        )
-
     def update(self, rev: Optional[str] = None):
         if rev:
             self.def_repo[self.PARAM_REV] = rev
-        with self._make_repo(locked=False) as repo:
-            self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev()
+        self.fs = self._make_fs(rev=rev, locked=False)
+        self.def_repo[self.PARAM_REV_LOCK] = self.fs.repo.get_rev()
 
     def changed_checksum(self) -> bool:
         # From current repo point of view what describes RepoDependency is its
         # origin project url and rev_lock, and it makes RepoDependency
         # immutable, hence its impossible for checksum to change.
         return False
 
-    def get_used_objs(self, **kwargs) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
-        used, _, _ = self._get_used_and_obj(**kwargs)
-        return used
-
-    def _get_used_and_obj(
-        self, obj_only: bool = False, **kwargs
-    ) -> Tuple[Dict[Optional["ObjectDB"], Set["HashInfo"]], "Meta", "HashFile"]:
-        from dvc.config import NoRemoteError
-        from dvc.exceptions import NoOutputOrStageError
-        from dvc.utils import as_posix
-        from dvc_data.hashfile.build import build
-        from dvc_data.hashfile.tree import Tree, TreeError
-
-        local_odb = self.repo.cache.local
-        locked = kwargs.pop("locked", True)
-        with self._make_repo(locked=locked, cache_dir=local_odb.path) as repo:
-            used_obj_ids = defaultdict(set)
-            rev = repo.get_rev()
-            if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None:
-                self.def_repo[self.PARAM_REV_LOCK] = rev
-
-            if not obj_only:
-                try:
-                    for odb, obj_ids in repo.used_objs(
-                        [os.path.join(repo.root_dir, self.def_path)],
-                        force=True,
-                        jobs=kwargs.get("jobs"),
-                        recursive=True,
-                    ).items():
-                        if odb is None:
-                            odb = repo.cloud.get_remote_odb()
-                            odb.read_only = True
-                        self._check_circular_import(odb, obj_ids)
-                        used_obj_ids[odb].update(obj_ids)
-                except (NoRemoteError, NoOutputOrStageError):
-                    pass
-
-            try:
-                object_store, meta, obj = build(
-                    local_odb,
-                    as_posix(self.def_path),
-                    repo.dvcfs,
-                    local_odb.fs.PARAM_CHECKSUM,
-                )
-            except (FileNotFoundError, TreeError) as exc:
-                raise FileNotFoundError(
-                    errno.ENOENT,
-                    os.strerror(errno.ENOENT) + f" in {self.def_repo[self.PARAM_URL]}",
-                    self.def_path,
-                ) from exc
-            object_store = copy(object_store)
-            object_store.read_only = True
-
-            self._objs[rev] = obj
-            self._meta[rev] = meta
-
-            used_obj_ids[object_store].add(obj.hash_info)
-            if isinstance(obj, Tree):
-                used_obj_ids[object_store].update(oid for _, _, oid in obj)
-            return used_obj_ids, meta, obj
-
-    def _check_circular_import(self, odb: "ObjectDB", obj_ids: Set["HashInfo"]) -> None:
-        from dvc.exceptions import CircularImportError
-        from dvc.fs.dvc import DVCFileSystem
-        from dvc_data.hashfile.db.reference import ReferenceHashFileDB
-        from dvc_data.hashfile.tree import Tree
-
-        if not isinstance(odb, ReferenceHashFileDB):
-            return
-
-        def iter_objs():
-            for hash_info in obj_ids:
-                if hash_info.isdir:
-                    tree = Tree.load(odb, hash_info)
-                    yield from (odb.get(hi.value) for _, _, hi in tree)
-                else:
-                    assert hash_info.value
-                    yield odb.get(hash_info.value)
-
-        checked_urls = set()
-        for obj in iter_objs():
-            if not isinstance(obj.fs, DVCFileSystem):
-                continue
-            if obj.fs.repo_url in checked_urls or obj.fs.repo.root_dir in checked_urls:
-                continue
-            self_url = self.repo.url or self.repo.root_dir
-            if (
-                obj.fs.repo_url is not None
-                and obj.fs.repo_url == self_url
-                or obj.fs.repo.root_dir == self.repo.root_dir
-            ):
-                raise CircularImportError(self, obj.fs.repo_url, self_url)
-            checked_urls.update([obj.fs.repo_url, obj.fs.repo.root_dir])
-
-    def get_obj(self, filter_info=None, **kwargs):
-        locked = kwargs.get("locked", True)
-        rev = self._get_rev(locked=locked)
-        if rev in self._objs:
-            return self._objs[rev]
-        _, _, obj = self._get_used_and_obj(
-            obj_only=True, filter_info=filter_info, **kwargs
-        )
-        return obj
-
-    def get_meta(self, filter_info=None, **kwargs):
-        locked = kwargs.get("locked", True)
-        rev = self._get_rev(locked=locked)
-        if rev in self._meta:
-            return self._meta[rev]
-        _, meta, _ = self._get_used_and_obj(
-            obj_only=True, filter_info=filter_info, **kwargs
-        )
-        return meta
-
-    def _make_repo(self, locked: bool = True, **kwargs) -> "ContextManager[Repo]":
-        from dvc.external_repo import external_repo
+    def _make_fs(
+        self, rev: Optional[str] = None, locked: bool = True
+    ) -> "DVCFileSystem":
+        from dvc.fs import DVCFileSystem
 
-        d = self.def_repo
-        rev = self._get_rev(locked=locked)
-        return external_repo(
-            d[self.PARAM_URL],
-            rev=rev,
+        return DVCFileSystem(
+            url=self.def_repo[self.PARAM_URL],
+            rev=rev or self._get_rev(locked=locked),
             subrepos=True,
             uninitialized=True,
-            **kwargs,
         )
 
     def _get_rev(self, locked: bool = True):
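
The body of _get_rev is collapsed above; the locked/unlocked distinction that
workspace_status and _make_fs rely on amounts to a preference order along these
lines (an assumed sketch, not the verbatim code):

    # Assumed semantics of _get_rev(locked=...): when locked, prefer the
    # pinned rev_lock; otherwise fall back to the user-requested rev.
    def resolve_rev(def_repo: dict, locked: bool = True):
        if locked and def_repo.get("rev_lock"):
            return def_repo["rev_lock"]
        return def_repo.get("rev")
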
8 changes: 0 additions & 8 deletions dvc/exceptions.py
@@ -338,14 +338,6 @@ def __init__(self, fs_paths):
         self.fs_paths = fs_paths
 
 
-class CircularImportError(DvcException):
-    def __init__(self, dep, a, b):
-        super().__init__(
-            f"'{dep}' contains invalid circular import. "
-            f"DVC repo '{a}' already imports from '{b}'."
-        )
-
-
 class PrettyDvcException(DvcException):
     def __pretty_exc__(self, **kwargs):
         """Print prettier exception message."""
15 changes: 2 additions & 13 deletions dvc/output.py
@@ -1066,7 +1066,7 @@ def _collect_used_dir_cache(
             return obj.filter(prefix)
         return obj
 
-    def get_used_objs(  # noqa: C901, PLR0911
+    def get_used_objs(  # noqa: C901
         self, **kwargs
     ) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
         """Return filtered set of used object IDs for this out."""
@@ -1076,9 +1076,7 @@ def get_used_objs(  # noqa: C901, PLR0911
 
         push: bool = kwargs.pop("push", False)
         if self.stage.is_repo_import:
-            if push:
-                return {}
-            return self.get_used_external(**kwargs)
+            return {}
 
         if push and not self.can_push:
             return {}
@@ -1130,15 +1128,6 @@ def _named_obj_ids(self, obj):
                 oids.add(oid)
         return oids
 
-    def get_used_external(
-        self, **kwargs
-    ) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
-        if not self.use_cache or not self.stage.is_repo_import:
-            return {}
-
-        (dep,) = self.stage.deps
-        return dep.get_used_objs(**kwargs)
-
     def _validate_output_path(self, path, stage=None):
         from dvc.dvcfile import is_valid_filename
 
17 changes: 4 additions & 13 deletions dvc/repo/imports.py
@@ -28,11 +28,7 @@ def unfetched_view(
     changed_deps: List["Dependency"] = []
 
    def need_fetch(stage: "Stage") -> bool:
-        if (
-            not stage.is_import
-            or stage.is_repo_import
-            or (stage.is_partial_import and not unpartial)
-        ):
+        if not stage.is_import or (stage.is_partial_import and not unpartial):
             return False
 
         out = stage.outs[0]
@@ -65,6 +61,7 @@ def unpartial_imports(index: Union["Index", "IndexView"]) -> int:
         Total number of files which were unpartialed.
     """
     from dvc_data.hashfile.hash_info import HashInfo
+    from dvc_data.hashfile.meta import Meta
 
     updated = 0
     for out in index.outs:
@@ -73,14 +70,8 @@ def unpartial_imports(index: Union["Index", "IndexView"]) -> int:
         workspace, key = out.index_key
         entry = index.data[workspace][key]
         if out.stage.is_partial_import:
-            if out.stage.is_repo_import:
-                dep = out.stage.deps[0]
-                out.hash_info = dep.get_obj().hash_info
-                out.meta = dep.get_meta()
-            else:
-                assert isinstance(entry.hash_info, HashInfo)
-                out.hash_info = entry.hash_info
-                out.meta = entry.meta
+            out.hash_info = entry.hash_info or HashInfo()
+            out.meta = entry.meta or Meta()
             out.stage.md5 = out.stage.compute_md5()
             out.stage.dump()
         updated += out.meta.nfiles if out.meta.nfiles is not None else 1
3 changes: 2 additions & 1 deletion dvc/repo/index.py
@@ -530,7 +530,8 @@ def _data_prefixes(self) -> Dict[str, "_DataPrefixes"]:
             workspace, key = out.index_key
             if filter_info and out.fs.path.isin(filter_info, out.fs_path):
                 key = key + out.fs.path.relparts(filter_info, out.fs_path)
-            if out.meta.isdir or out.stage.is_import and out.stage.deps[0].meta.isdir:
+            entry = self._index.data[workspace][key]
+            if entry and entry.meta and entry.meta.isdir:
                 prefixes[workspace].recursive.add(key)
             prefixes[workspace].explicit.update(key[:i] for i in range(len(key), 0, -1))
         return prefixes
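
The rewritten check consults the data index entry's meta instead of the
output's own meta (which a not-yet-fetched import may lack); a sketch of the
new condition for a single out, with hypothetical workspace/key values:

    # Sketch only: entry comes from the repo's data index.
    entry = index.data["repo"][("data",)]
    if entry and entry.meta and entry.meta.isdir:
        prefixes["repo"].recursive.add(("data",))
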
5 changes: 4 additions & 1 deletion dvc/repo/init.py
@@ -77,7 +77,10 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):  # noqa:
 
     if os.path.isdir(proj.site_cache_dir):
         proj.close()
-        remove(proj.site_cache_dir)
+        try:
+            remove(proj.site_cache_dir)
+        except OSError:
+            logger.debug("failed to remove %s", dvc_dir, exc_info=True)
         proj = Repo(root_dir)
 
     with proj.scm_context(autostage=True) as context:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
     "configobj>=5.0.6",
     "distro>=1.3",
     "dpath<3,>=2.1.0",
-    "dvc-data>=0.46.0,<0.47",
+    "dvc-data>=0.47.1,<0.48",
     "dvc-http",
     "dvc-render>=0.3.1,<0.4.0",
     "dvc-studio-client>=0.6.1,<1",
5 changes: 3 additions & 2 deletions tests/func/test_data_cloud.py
@@ -226,12 +226,13 @@ def test_pull_git_imports(tmp_dir, dvc, scm, erepo):
 
     assert dvc.pull()["fetched"] == 0
 
-    for item in ["foo", "new_dir", dvc.cache.local.path]:
+    for item in ["foo", "new_dir"]:
         remove(item)
+    dvc.cache.local.clear()
     os.makedirs(dvc.cache.local.path, exist_ok=True)
     clean_repos()
 
-    assert dvc.pull(force=True)["fetched"] == 3
+    assert dvc.pull(force=True)["fetched"] == 2
 
     assert (tmp_dir / "foo").exists()
     assert (tmp_dir / "foo").read_text() == "foo"
(The diff for the ninth changed file did not load and is not shown.)
