diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 998cac08f8..170e888e75 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -56,7 +56,8 @@ def _get_checksum(self, locked=True): except OutputNotFoundError: path = PathInfo(os.path.join(repo.root_dir, self.def_path)) # We are polluting our repo cache with some dir listing here - return repo.get_checksum(path, self.repo.cache.local) + with repo.use_cache(self.repo.cache.local): + return repo.get_checksum(path) def status(self): current_checksum = self._get_checksum(locked=True) @@ -78,10 +79,8 @@ def download(self, to): if self.def_repo.get(self.PARAM_REV_LOCK) is None: self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev() - if hasattr(repo, "cache"): - repo.cache.local.cache_dir = self.repo.cache.local.cache_dir - - repo.get_external(self.def_path, to.path_info) + with repo.use_cache(self.repo.cache.local): + repo.get_external(self.def_path, to.path_info) def update(self, rev=None): if rev: diff --git a/dvc/external_repo.py b/dvc/external_repo.py index a597ba5817..4d1828bf94 100644 --- a/dvc/external_repo.py +++ b/dvc/external_repo.py @@ -68,6 +68,21 @@ def clean_repos(): class BaseExternalRepo: _tree_rev = None + _local_cache = None + + @contextmanager + def use_cache(self, cache): + """Use specified cache instead of erepo tmpdir cache.""" + self._local_cache = cache + if hasattr(self, "cache"): + save_cache = self.cache.local + self.cache.local = cache + # make cache aware of our repo tree + with cache.erepo_tree(self.tree): + yield + if hasattr(self, "cache"): + self.cache.local = save_cache + self._local_cache = None @cached_property def repo_tree(self): @@ -79,9 +94,8 @@ def get_rev(self): return self._tree_rev return self.scm.get_rev() - def get_checksum(self, path, cache): - with cache.erepo_tree(self.tree): - return cache.get_checksum(path) + def get_checksum(self, path): + raise NotImplementedError def get_external(self, path, to_info, **kwargs): """ @@ -96,29 +110,26 @@ def get_external(self, path, to_info, **kwargs): if not self.repo_tree.exists(path_info): raise PathMissingError(path, self.url) - if self.repo_tree.isdvc(path_info): - local_cache = self.cache.local - else: - local_cache = None - if self.repo_tree.isdvc(path_info): (out,) = self.find_outs_by_path(path_info, strict=False) filter_info = path_info.relative_to(out.path_info) - self._fetch_external( - [path_info], local_cache, filter_info=filter_info, **kwargs - ) + else: + filter_info = None + + self._fetch_external([path_info], filter_info=filter_info, **kwargs) with self.state: self.repo_tree.copytree(path_info, to_info) - def fetch_external(self, files, cache, **kwargs): + def fetch_external(self, files, **kwargs): """Fetch erepo files into the specified cache. Works with files tracked by Git and DVC. """ - raise NotImplementedError + files = [PathInfo(self.root_dir) / name for name in files] + return self._fetch_external(files, **kwargs) - def _fetch_external(self, path_infos, cache, **kwargs): + def _fetch_external(self, path_infos, **kwargs): downloaded, failed = 0, 0 with self.state: @@ -127,7 +138,7 @@ def _fetch_external(self, path_infos, cache, **kwargs): (out,) = self.find_outs_by_path(path_info, strict=False) d, f = self._fetch_out(out, **kwargs) else: - d, f = self._fetch_git(path_info, cache) + d, f = self._fetch_git(path_info) downloaded += d failed += f @@ -150,9 +161,16 @@ def _fetch_out(self, out, filter_info=None, **kwargs): failed += exc.amount return downloaded, failed - def _fetch_git(self, path_info, local_cache): + def _fetch_git(self, path_info): """Copy git tracked file into specified cache.""" downloaded, failed = 0, 0 + if hasattr(self, "cache"): + local_cache = self.cache.local + elif self._local_cache: + local_cache = self._local_cache + else: + return downloaded, failed + info = local_cache.save_info(path_info) if info.get(local_cache.PARAM_CHECKSUM) is None: logger.exception( @@ -236,10 +254,8 @@ def _add_upstream(self, src_repo): self.config["remote"]["auto-generated-upstream"] = {"url": cache_dir} self.config["core"]["remote"] = "auto-generated-upstream" - def fetch_external(self, files, cache, **kwargs): - self.cache.local.cache_dir = cache.cache_dir - files = [PathInfo(self.root_dir) / name for name in files] - return self._fetch_external(files, self.cache.local, **kwargs) + def get_checksum(self, path): + return self.cache.local.get_checksum(path) class ExternalGitRepo(BaseExternalRepo): @@ -272,10 +288,8 @@ def open_by_relpath(self, path, mode="r", encoding=None, **kwargs): except FileNotFoundError: raise PathMissingError(path, self.url) - def fetch_external(self, files, cache, **kwargs): - files = [PathInfo(self.root_dir) / name for name in files] - with cache.erepo_tree(self.tree): - return self._fetch_external(files, cache, **kwargs) + def get_checksum(self, path): + return self._local_cache.get_checksum(path) def _cached_clone(url, rev, for_write=False): diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 5a48e81143..3db02abd31 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -78,4 +78,5 @@ def _fetch_external(self, repo_url, repo_rev, files, **kwargs): from dvc.external_repo import external_repo with external_repo(repo_url, repo_rev) as repo: - return repo.fetch_external(files, self.cache.local, **kwargs) + with repo.use_cache(self.cache.local): + return repo.fetch_external(files, **kwargs)