Skip to content

Commit

Permalink
erepo: add support for using DependencyRepo's cache instead of tmp
Browse files Browse the repository at this point in the history
* Will close iterative#3611.
  • Loading branch information
pmrowla committed May 14, 2020
1 parent 0941a0b commit 6cf1065
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 30 deletions.
9 changes: 4 additions & 5 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _get_checksum(self, locked=True):
except OutputNotFoundError:
path = PathInfo(os.path.join(repo.root_dir, self.def_path))
# We are polluting our repo cache with some dir listing here
return repo.get_checksum(path, self.repo.cache.local)
with repo.use_cache(self.repo.cache.local):
return repo.get_checksum(path)

def status(self):
current_checksum = self._get_checksum(locked=True)
Expand All @@ -78,10 +79,8 @@ def download(self, to):
if self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev()

if hasattr(repo, "cache"):
repo.cache.local.cache_dir = self.repo.cache.local.cache_dir

repo.get_external(self.def_path, to.path_info)
with repo.use_cache(self.repo.cache.local):
repo.get_external(self.def_path, to.path_info)

def update(self, rev=None):
if rev:
Expand Down
62 changes: 38 additions & 24 deletions dvc/external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ def clean_repos():

class BaseExternalRepo:
_tree_rev = None
_local_cache = None

@contextmanager
def use_cache(self, cache):
    """Use specified cache instead of erepo tmpdir cache.

    While the context is active, ``cache`` is stored in
    ``self._local_cache`` and, when this repo has a ``cache`` attribute,
    it also replaces ``self.cache.local``. The body runs inside
    ``cache.erepo_tree(self.tree)``.

    On exit, whether the body finished normally or raised, the previous
    ``self.cache.local`` is put back and ``self._local_cache`` is reset
    to ``None``.

    Args:
        cache: cache object to use for the duration of the context; it
            must provide an ``erepo_tree(tree)`` context manager.
    """
    self._local_cache = cache
    # Decide once whether there is a repo-level cache to swap, so that
    # setup and teardown always agree with each other.
    has_cache = hasattr(self, "cache")
    if has_cache:
        save_cache = self.cache.local
        self.cache.local = cache
    try:
        # make cache aware of our repo tree
        with cache.erepo_tree(self.tree):
            yield
    finally:
        # An exception raised in the ``with`` body is re-raised at the
        # ``yield`` above; without ``finally`` the swapped-in cache
        # would stay attached to this repo.
        if has_cache:
            self.cache.local = save_cache
        self._local_cache = None

@cached_property
def repo_tree(self):
Expand All @@ -79,9 +94,8 @@ def get_rev(self):
return self._tree_rev
return self.scm.get_rev()

def get_checksum(self, path, cache):
with cache.erepo_tree(self.tree):
return cache.get_checksum(path)
def get_checksum(self, path):
raise NotImplementedError

def get_external(self, path, to_info, **kwargs):
"""
Expand All @@ -96,29 +110,26 @@ def get_external(self, path, to_info, **kwargs):
if not self.repo_tree.exists(path_info):
raise PathMissingError(path, self.url)

if self.repo_tree.isdvc(path_info):
local_cache = self.cache.local
else:
local_cache = None

if self.repo_tree.isdvc(path_info):
(out,) = self.find_outs_by_path(path_info, strict=False)
filter_info = path_info.relative_to(out.path_info)
self._fetch_external(
[path_info], local_cache, filter_info=filter_info, **kwargs
)
else:
filter_info = None

self._fetch_external([path_info], filter_info=filter_info, **kwargs)

with self.state:
self.repo_tree.copytree(path_info, to_info)

def fetch_external(self, files, cache, **kwargs):
def fetch_external(self, files, **kwargs):
"""Fetch erepo files into the specified cache.
Works with files tracked by Git and DVC.
"""
raise NotImplementedError
files = [PathInfo(self.root_dir) / name for name in files]
return self._fetch_external(files, **kwargs)

def _fetch_external(self, path_infos, cache, **kwargs):
def _fetch_external(self, path_infos, **kwargs):
downloaded, failed = 0, 0

with self.state:
Expand All @@ -127,7 +138,7 @@ def _fetch_external(self, path_infos, cache, **kwargs):
(out,) = self.find_outs_by_path(path_info, strict=False)
d, f = self._fetch_out(out, **kwargs)
else:
d, f = self._fetch_git(path_info, cache)
d, f = self._fetch_git(path_info)
downloaded += d
failed += f

Expand All @@ -150,9 +161,16 @@ def _fetch_out(self, out, filter_info=None, **kwargs):
failed += exc.amount
return downloaded, failed

def _fetch_git(self, path_info, local_cache):
def _fetch_git(self, path_info):
"""Copy git tracked file into specified cache."""
downloaded, failed = 0, 0
if hasattr(self, "cache"):
local_cache = self.cache.local
elif self._local_cache:
local_cache = self._local_cache
else:
return downloaded, failed

info = local_cache.save_info(path_info)
if info.get(local_cache.PARAM_CHECKSUM) is None:
logger.exception(
Expand Down Expand Up @@ -236,10 +254,8 @@ def _add_upstream(self, src_repo):
self.config["remote"]["auto-generated-upstream"] = {"url": cache_dir}
self.config["core"]["remote"] = "auto-generated-upstream"

def fetch_external(self, files, cache, **kwargs):
self.cache.local.cache_dir = cache.cache_dir
files = [PathInfo(self.root_dir) / name for name in files]
return self._fetch_external(files, self.cache.local, **kwargs)
def get_checksum(self, path):
return self.cache.local.get_checksum(path)


class ExternalGitRepo(BaseExternalRepo):
Expand Down Expand Up @@ -272,10 +288,8 @@ def open_by_relpath(self, path, mode="r", encoding=None, **kwargs):
except FileNotFoundError:
raise PathMissingError(path, self.url)

def fetch_external(self, files, cache, **kwargs):
files = [PathInfo(self.root_dir) / name for name in files]
with cache.erepo_tree(self.tree):
return self._fetch_external(files, cache, **kwargs)
def get_checksum(self, path):
return self._local_cache.get_checksum(path)


def _cached_clone(url, rev, for_write=False):
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,5 @@ def _fetch_external(self, repo_url, repo_rev, files, **kwargs):
from dvc.external_repo import external_repo

with external_repo(repo_url, repo_rev) as repo:
return repo.fetch_external(files, self.cache.local, **kwargs)
with repo.use_cache(self.cache.local):
return repo.fetch_external(files, **kwargs)

0 comments on commit 6cf1065

Please sign in to comment.