Skip to content

Commit

Permalink
erepo: more tree abstractions
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed May 14, 2020
1 parent bf371a6 commit dc71713
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 162 deletions.
147 changes: 24 additions & 123 deletions dvc/external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@

from dvc.config import NoRemoteError, NotDvcRepoError
from dvc.exceptions import (
DownloadError,
FileMissingError,
NoOutputInExternalRepoError,
NoRemoteInExternalRepoError,
OutputNotFoundError,
PathMissingError,
RecursiveImportError,
)
from dvc.path_info import PathInfo
from dvc.repo import Repo
Expand Down Expand Up @@ -72,20 +70,6 @@ class BaseExternalRepo:
_tree_rev = None
_local_cache = None

@contextmanager
def use_cache(self, cache):
"""Use specified cache instead of erepo tmpdir cache."""
self._local_cache = cache
if hasattr(self, "cache"):
save_cache = self.cache.local
self.cache.local = cache
# make cache aware of our repo tree
with cache.erepo_tree(self.tree):
yield
if hasattr(self, "cache"):
self.cache.local = save_cache
self._local_cache = None

@cached_property
def repo_tree(self):
return RepoTree(self)
Expand Down Expand Up @@ -115,89 +99,49 @@ def _open_by_repo_tree_relpath(
except FileNotFoundError:
raise PathMissingError(path, self.url)

def get_checksum(self, path):
raise NotImplementedError

def check_recursive_imports(self, path):
"""Raise RecursiveImportError if path_info contains recursively added
DVC outs.
"""
path_info = PathInfo(self.root_dir) / path
self._recursive_outputs(path_info)

def _recursive_outputs(self, path_info, recursive=False):
# if path_info is a non-dvc directory, we need to check for
# recursively added dvc files
fetch_infos = []
for root, dirs, files in self.repo_tree.walk(path_info):
root_path = PathInfo(root)
for name in dirs + files:
if name == Repo.DVC_DIR:
# import from subrepos currently unsupported
raise RecursiveImportError(
path_info.relative_to(self.root_dir), subrepo=True
)
if self.repo_tree.isdvc(root_path / name):
if recursive:
fetch_infos.append(self._fetch_info(root_path / name))
else:
raise RecursiveImportError(
path_info.relative_to(self.root_dir)
)
return fetch_infos

def _fetch_info(self, path_info):
if self.repo_tree.isdvc(path_info):
(out,) = self.find_outs_by_path(path_info, strict=False)
filter_info = path_info.relative_to(out.path_info)
else:
filter_info = None
return path_info, filter_info

def get_external(self, path, to_info, recursive=False, **kwargs):
def get_external(self, path, to_info, cache=None, **kwargs):
"""
Pull the corresponding file or directory specified by `path` and
into `to_info`.
It works with files tracked by Git and DVC, and also local files
outside the repository.
"""
path_info = PathInfo(self.root_dir) / path

if not self.repo_tree.exists(path_info):
raise PathMissingError(path, self.url)
save_git = False
if cache:
save_git = True
elif hasattr(self, "cache"):
cache = self.cache.local

fetch_infos = [self._fetch_info(path_info)]
if self.repo_tree.isdir(path_info) and not self.repo_tree.isdvc(
path_info
):
fetch_infos.extend(self._recursive_outputs(path_info, recursive))

self._fetch_external(fetch_infos, **kwargs)
path_info = PathInfo(self.root_dir) / path
self._fetch_external([path_info], cache, save_git=save_git, **kwargs)

with self.state:
self.repo_tree.copytree(path_info, to_info)
self.repo_tree.checkout(path_info, to_info, cache)

def fetch_external(self, files, **kwargs):
def fetch_external(self, files, cache=None, **kwargs):
"""Fetch erepo files into the specified cache.
Works with files tracked by Git and DVC.
"""
files = [(PathInfo(self.root_dir) / name, None) for name in files]
return self._fetch_external(files, **kwargs)
save_git = False
if cache:
save_git = True
elif hasattr(self, "cache"):
cache = self.cache.local

files = [PathInfo(self.root_dir) / name for name in files]
return self._fetch_external(files, cache, save_git=save_git, **kwargs)

def _fetch_external(self, fetch_infos, **kwargs):
def _fetch_external(self, path_infos, cache, **kwargs):
downloaded, failed = 0, 0

with self.state:
for path_info, filter_info in fetch_infos:
if self.repo_tree.isdvc(path_info):
(out,) = self.find_outs_by_path(path_info, strict=False)
d, f = self._fetch_out(
out, filter_info=filter_info, **kwargs
)
else:
d, f = self._fetch_git(path_info)
for path_info in path_infos:
try:
d, f = self.repo_tree.fetch(path_info, cache, **kwargs)
except FileNotFoundError:
raise PathMissingError(path_info, self.url)
downloaded += d
failed += f

Expand All @@ -209,43 +153,6 @@ def _fetch_external(self, fetch_infos, **kwargs):
)
return downloaded, failed

def _fetch_out(self, out, filter_info=None, **kwargs):
"""Fetch specified erepo out."""
downloaded, failed = 0, 0
if out.changed_cache(filter_info=filter_info):
used_cache = out.get_used_cache()
try:
downloaded += self.cloud.pull(used_cache, **kwargs)
except DownloadError as exc:
failed += exc.amount
return downloaded, failed

def _fetch_git(self, path_info):
"""Copy git tracked file into specified cache."""
downloaded, failed = 0, 0
if hasattr(self, "cache"):
local_cache = self.cache.local
elif self._local_cache:
local_cache = self._local_cache
else:
return downloaded, failed

info = local_cache.save_info(path_info)
if info.get(local_cache.PARAM_CHECKSUM) is None:
logger.exception(
"failed to fetch '{}' from '{}' repo".format(
path_info, self.url
)
)
failed += 1
elif local_cache.changed_cache(info[local_cache.PARAM_CHECKSUM]):
local_cache.save_tree(self.repo_tree, path_info, info)
logger.debug(
"fetched '{}' from '{}' repo".format(path_info, self.url)
)
downloaded += 1
return downloaded, failed


class ExternalRepo(Repo, BaseExternalRepo):
def __init__(self, root_dir, url, rev, for_write=False):
Expand Down Expand Up @@ -313,9 +220,6 @@ def _add_upstream(self, src_repo):
self.config["remote"]["auto-generated-upstream"] = {"url": cache_dir}
self.config["core"]["remote"] = "auto-generated-upstream"

def get_checksum(self, path):
return self.cache.local.get_checksum(path)

@contextmanager
def open_by_relpath(self, path, remote=None, mode="r", encoding=None):
"""Opens a specified resource as a file object."""
Expand Down Expand Up @@ -363,9 +267,6 @@ def open_by_relpath(self, path, mode="r", encoding=None, **kwargs):
) as fobj:
yield fobj

def get_checksum(self, path):
return self._local_cache.get_checksum(path)


def _cached_clone(url, rev, for_write=False):
"""Clone an external git repo to a temporary directory.
Expand Down
Loading

0 comments on commit dc71713

Please sign in to comment.