From 6964f8160b663fbedfc34dbd921a0fb84489e616 Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Mon, 9 Aug 2021 22:09:04 +0300
Subject: [PATCH] repofs: use underlying fs.download to download files

---
Review notes (free text in this position, between "---" and the first
"diff --git", is not part of the commit message):

* dvc/fs/dvc.py: the body of open() becomes _get_fs_path(), which returns
  an (fs, path) pair instead of an open file object: either
  (remote_odb.fs, remote_info) or (out.odb.fs, cache_path). open() and the
  new _download() both resolve that pair and delegate to the returned fs.
* dvc/fs/git.py: gains the generic copy-based _download() (open the source
  in "rb", copy into to_file through a Tqdm-wrapped writer) that is removed
  from dvc/fs/repo.py.
* dvc/fs/repo.py: _download() now calls _download() on the first fs from
  _get_fs_pair(); on FileNotFoundError it falls back to dvc_fs, and
  re-raises when dvc_fs is falsy.
* NOTE(review): open() forwards all of **kwargs to _get_fs_path(), whose
  signature is (path, remote=None). Any keyword other than `remote` would
  raise TypeError, whereas the old open() accepted and ignored extra
  keywords. Confirm that callers pass nothing else.
* NOTE(review): _download() in dvc/fs/dvc.py calls _get_fs_path(from_info)
  without `remote`. Confirm that always using the default remote is
  intended for downloads.
* NOTE(review): this copy had lost its line breaks and indentation. Leading
  whitespace inside the hunks was reconstructed from Python syntax and
  checked against the hunk line counts; verify against the target tree
  (or apply with --ignore-whitespace). The author e-mail address is absent
  from this copy and was not reconstructed.

 dvc/fs/dvc.py  | 22 +++++++++++++++-------
 dvc/fs/git.py  | 14 ++++++++++++++
 dvc/fs/repo.py | 25 +++++++++++++------------
 3 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/dvc/fs/dvc.py b/dvc/fs/dvc.py
index 2e5182e4ee..b92caf8611 100644
--- a/dvc/fs/dvc.py
+++ b/dvc/fs/dvc.py
@@ -62,9 +62,7 @@ def _get_granular_hash(
             return obj.hash_info
         raise FileNotFoundError
 
-    def open(  # type: ignore
-        self, path: PathInfo, mode="r", encoding=None, remote=None, **kwargs
-    ):  # pylint: disable=arguments-differ
+    def _get_fs_path(self, path: PathInfo, remote=None):
         try:
             outs = self._find_outs(path, strict=False)
         except OutputNotFoundError as exc:
@@ -92,16 +90,20 @@ def open(  # type: ignore
             else:
                 checksum = out.hash_info.value
             remote_info = remote_odb.hash_to_path_info(checksum)
-            return remote_odb.fs.open(
-                remote_info, mode=mode, encoding=encoding
-            )
+            return remote_odb.fs, remote_info
 
         if out.is_dir_checksum:
             checksum = self._get_granular_hash(path, out).value
             cache_path = out.odb.hash_to_path_info(checksum).url
         else:
             cache_path = out.cache_path
-        return open(cache_path, mode=mode, encoding=encoding)
+        return out.odb.fs, cache_path
+
+    def open(  # type: ignore
+        self, path: PathInfo, mode="r", encoding=None, **kwargs
+    ):  # pylint: disable=arguments-renamed
+        fs, fspath = self._get_fs_path(path, **kwargs)
+        return fs.open(fspath, mode=mode, encoding=encoding)
 
     def exists(self, path):  # pylint: disable=arguments-renamed
         try:
@@ -253,3 +255,9 @@ def info(self, path_info):
                 ret[obj.hash_info.name] = obj.hash_info.value
 
         return ret
+
+    def _download(self, from_info, to_file, **kwargs):
+        fs, path = self._get_fs_path(from_info)
+        fs._download(  # pylint: disable=protected-access
+            path, to_file, **kwargs
+        )
diff --git a/dvc/fs/git.py b/dvc/fs/git.py
index b55d164891..46c4127cfc 100644
--- a/dvc/fs/git.py
+++ b/dvc/fs/git.py
@@ -126,3 +126,17 @@ def walk_files(self, path_info, **kwargs):
             for file in files:
                 # NOTE: os.path.join is ~5.5 times slower
                 yield f"{root}{os.sep}{file}"
+
+    def _download(
+        self, from_info, to_file, name=None, no_progress_bar=False, **kwargs
+    ):
+        import shutil
+
+        from dvc.progress import Tqdm
+
+        with open(to_file, "wb+") as to_fobj:
+            with Tqdm.wrapattr(
+                to_fobj, "write", desc=name, disable=no_progress_bar
+            ) as wrapped:
+                with self.open(from_info, "rb", **kwargs) as from_fobj:
+                    shutil.copyfileobj(from_fobj, wrapped)
diff --git a/dvc/fs/repo.py b/dvc/fs/repo.py
index 10f5764516..c1517abbee 100644
--- a/dvc/fs/repo.py
+++ b/dvc/fs/repo.py
@@ -453,19 +453,20 @@ def walk_files(self, path_info, **kwargs):
             for fname in files:
                 yield PathInfo(root) / fname
 
-    def _download(
-        self, from_info, to_file, name=None, no_progress_bar=False, **kwargs
-    ):
-        import shutil
-
-        from dvc.progress import Tqdm
+    def _download(self, from_info, to_file, **kwargs):
+        fs, dvc_fs = self._get_fs_pair(from_info)
+        try:
+            fs._download(  # pylint: disable=protected-access
+                from_info, to_file, **kwargs
+            )
+            return
+        except FileNotFoundError:
+            if not dvc_fs:
+                raise
 
-        with open(to_file, "wb+") as to_fobj:
-            with Tqdm.wrapattr(
-                to_fobj, "write", desc=name, disable=no_progress_bar
-            ) as wrapped:
-                with self.open(from_info, "rb", **kwargs) as from_fobj:
-                    shutil.copyfileobj(from_fobj, wrapped)
+        dvc_fs._download(  # pylint: disable=protected-access
+            from_info, to_file, **kwargs
+        )
 
     def metadata(self, path):
         abspath = os.path.abspath(path)