From 76228a21183e996b613fa01944d2d953702335e5 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Wed, 15 Jul 2020 23:44:09 +0300 Subject: [PATCH] dvc: split remotes from trees (#4212) * dvc: remove tree-specific wrappers from remote This helps combat the remote/tree/cache confusion. * dvc: split remotes from trees Related to #4050 --- dvc/api.py | 2 +- dvc/command/version.py | 2 +- dvc/dependency/__init__.py | 2 +- dvc/dependency/azure.py | 3 +- dvc/dependency/http.py | 3 +- dvc/dependency/https.py | 3 +- dvc/main.py | 2 +- dvc/output/__init__.py | 9 +- dvc/output/base.py | 3 +- dvc/output/gs.py | 3 +- dvc/output/hdfs.py | 3 +- dvc/output/local.py | 3 +- dvc/output/s3.py | 3 +- dvc/output/ssh.py | 3 +- dvc/remote/__init__.py | 83 +-- dvc/remote/base.py | 707 +------------------------ dvc/remote/local.py | 312 +---------- dvc/remote/ssh.py | 73 +++ dvc/repo/tree.py | 4 +- dvc/stage/utils.py | 4 +- dvc/tree/__init__.py | 78 +++ dvc/{remote => tree}/azure.py | 3 +- dvc/tree/base.py | 703 ++++++++++++++++++++++++ dvc/{remote => tree}/gdrive.py | 3 +- dvc/{remote => tree}/gs.py | 3 +- dvc/{remote => tree}/hdfs.py | 0 dvc/{remote => tree}/http.py | 3 +- dvc/{remote => tree}/https.py | 0 dvc/tree/local.py | 309 +++++++++++ dvc/{remote => tree}/oss.py | 3 +- dvc/{remote => tree}/pool.py | 0 dvc/{remote => tree}/s3.py | 3 +- dvc/{remote => tree}/ssh/__init__.py | 72 +-- dvc/{remote => tree}/ssh/connection.py | 3 +- tests/conftest.py | 2 +- tests/func/remote/test_gdrive.py | 2 +- tests/func/test_add.py | 4 +- tests/func/test_checkout.py | 14 +- tests/func/test_remote.py | 16 +- tests/func/test_s3.py | 2 +- tests/remotes/gdrive.py | 2 +- tests/remotes/ssh.py | 6 +- tests/unit/remote/ssh/test_pool.py | 4 +- tests/unit/remote/ssh/test_ssh.py | 2 +- tests/unit/remote/test_azure.py | 2 +- tests/unit/remote/test_base.py | 4 +- tests/unit/remote/test_gdrive.py | 2 +- tests/unit/remote/test_gs.py | 2 +- tests/unit/remote/test_http.py | 2 +- tests/unit/remote/test_local.py | 3 +- tests/unit/remote/test_oss.py | 2 +- tests/unit/remote/test_remote.py | 6 +- tests/unit/remote/test_remote_tree.py | 32 +- tests/unit/remote/test_s3.py | 2 +- 54 files changed, 1292 insertions(+), 1229 deletions(-) create mode 100644 dvc/remote/ssh.py create mode 100644 dvc/tree/__init__.py rename dvc/{remote => tree}/azure.py (99%) create mode 100644 dvc/tree/base.py rename dvc/{remote => tree}/gdrive.py (99%) rename dvc/{remote => tree}/gs.py (99%) rename dvc/{remote => tree}/hdfs.py (100%) rename dvc/{remote => tree}/http.py (99%) rename dvc/{remote => tree}/https.py (100%) create mode 100644 dvc/tree/local.py rename dvc/{remote => tree}/oss.py (98%) rename dvc/{remote => tree}/pool.py (100%) rename dvc/{remote => tree}/s3.py (99%) rename dvc/{remote => tree}/ssh/__init__.py (77%) rename dvc/{remote => tree}/ssh/connection.py (99%) diff --git a/dvc/api.py b/dvc/api.py index e469d24063..a2d648aa17 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -30,7 +30,7 @@ def get_url(path, repo=None, rev=None, remote=None): raise UrlNotDvcRepoError(_repo.url) # pylint: disable=no-member out = _repo.find_out_by_relpath(path) remote_obj = _repo.cloud.get_remote(remote) - return str(remote_obj.hash_to_path_info(out.checksum)) + return str(remote_obj.tree.hash_to_path_info(out.checksum)) def open( # noqa, pylint: disable=redefined-builtin diff --git a/dvc/command/version.py b/dvc/command/version.py index 02d38e81fb..589eb548da 100644 --- a/dvc/command/version.py +++ b/dvc/command/version.py @@ -122,7 +122,7 @@ def get_linktype_support_info(repo): 
@staticmethod def get_supported_remotes(): - from dvc.remote import TREES + from dvc.tree import TREES supported_remotes = [] for tree_cls in TREES: diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index ab737fdd26..91bd96c232 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -12,9 +12,9 @@ from dvc.dependency.s3 import S3Dependency from dvc.dependency.ssh import SSHDependency from dvc.output.base import BaseOutput -from dvc.remote import get_cloud_tree from dvc.scheme import Schemes +from ..tree import get_cloud_tree from .repo import RepoDependency DEPS = [ diff --git a/dvc/dependency/azure.py b/dvc/dependency/azure.py index da7b0727c3..ad93443f13 100644 --- a/dvc/dependency/azure.py +++ b/dvc/dependency/azure.py @@ -1,6 +1,7 @@ from dvc.dependency.base import BaseDependency from dvc.output.base import BaseOutput -from dvc.remote.azure import AzureRemoteTree + +from ..tree.azure import AzureRemoteTree class AzureDependency(BaseDependency, BaseOutput): diff --git a/dvc/dependency/http.py b/dvc/dependency/http.py index e8f0bf1e3d..4ddb18d621 100644 --- a/dvc/dependency/http.py +++ b/dvc/dependency/http.py @@ -1,6 +1,7 @@ from dvc.dependency.base import BaseDependency from dvc.output.base import BaseOutput -from dvc.remote.http import HTTPRemoteTree + +from ..tree.http import HTTPRemoteTree class HTTPDependency(BaseDependency, BaseOutput): diff --git a/dvc/dependency/https.py b/dvc/dependency/https.py index e95ac83f67..8d9b686435 100644 --- a/dvc/dependency/https.py +++ b/dvc/dependency/https.py @@ -1,5 +1,4 @@ -from dvc.remote.https import HTTPSRemoteTree - +from ..tree.https import HTTPSRemoteTree from .http import HTTPDependency diff --git a/dvc/main.py b/dvc/main.py index 7bf18bb2cc..dc737aa5f5 100644 --- a/dvc/main.py +++ b/dvc/main.py @@ -9,7 +9,7 @@ from dvc.exceptions import DvcException, DvcParserError, NotDvcRepoError from dvc.external_repo import clean_repos from dvc.logger import FOOTER, disable_other_loggers -from dvc.remote.pool import close_pools +from dvc.tree.pool import close_pools from dvc.utils import format_link # Workaround for CPython bug. See [1] and [2] for more info. 
diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 3f7654d4e0..5942fb1125 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -10,12 +10,13 @@ from dvc.output.local import LocalOutput from dvc.output.s3 import S3Output from dvc.output.ssh import SSHOutput -from dvc.remote import get_cloud_tree -from dvc.remote.hdfs import HDFSRemoteTree -from dvc.remote.local import LocalRemoteTree -from dvc.remote.s3 import S3RemoteTree from dvc.scheme import Schemes +from ..tree import get_cloud_tree +from ..tree.hdfs import HDFSRemoteTree +from ..tree.local import LocalRemoteTree +from ..tree.s3 import S3RemoteTree + OUTS = [ HDFSOutput, S3Output, diff --git a/dvc/output/base.py b/dvc/output/base.py index aba44a7960..ccce194540 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -12,7 +12,8 @@ DvcException, RemoteCacheRequiredError, ) -from dvc.remote.base import BaseRemoteTree + +from ..tree.base import BaseRemoteTree logger = logging.getLogger(__name__) diff --git a/dvc/output/gs.py b/dvc/output/gs.py index ccab57a4f7..04c3bab373 100644 --- a/dvc/output/gs.py +++ b/dvc/output/gs.py @@ -1,5 +1,6 @@ from dvc.output.s3 import S3Output -from dvc.remote.gs import GSRemoteTree + +from ..tree.gs import GSRemoteTree class GSOutput(S3Output): diff --git a/dvc/output/hdfs.py b/dvc/output/hdfs.py index a6632db308..49b8414564 100644 --- a/dvc/output/hdfs.py +++ b/dvc/output/hdfs.py @@ -1,5 +1,6 @@ from dvc.output.base import BaseOutput -from dvc.remote.hdfs import HDFSRemoteTree + +from ..tree.hdfs import HDFSRemoteTree class HDFSOutput(BaseOutput): diff --git a/dvc/output/local.py b/dvc/output/local.py index 9783773578..fd2c8b2860 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -5,10 +5,11 @@ from dvc.exceptions import DvcException from dvc.istextfile import istextfile from dvc.output.base import BaseOutput -from dvc.remote.local import LocalRemoteTree from dvc.utils import relpath from dvc.utils.fs import path_isin +from ..tree.local import LocalRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/output/s3.py b/dvc/output/s3.py index 92be85e510..0b139a0eac 100644 --- a/dvc/output/s3.py +++ b/dvc/output/s3.py @@ -1,5 +1,6 @@ from dvc.output.base import BaseOutput -from dvc.remote.s3 import S3RemoteTree + +from ..tree.s3 import S3RemoteTree class S3Output(BaseOutput): diff --git a/dvc/output/ssh.py b/dvc/output/ssh.py index 017f5de5a6..aac712bc36 100644 --- a/dvc/output/ssh.py +++ b/dvc/output/ssh.py @@ -1,5 +1,6 @@ from dvc.output.base import BaseOutput -from dvc.remote.ssh import SSHRemoteTree + +from ..tree.ssh import SSHRemoteTree class SSHOutput(BaseOutput): diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index d375ac202f..0c622d4b40 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,82 +1,7 @@ -import posixpath -from urllib.parse import urlparse - -from dvc.remote.azure import AzureRemoteTree -from dvc.remote.base import Remote -from dvc.remote.gdrive import GDriveRemoteTree -from dvc.remote.gs import GSRemoteTree -from dvc.remote.hdfs import HDFSRemoteTree -from dvc.remote.http import HTTPRemoteTree -from dvc.remote.https import HTTPSRemoteTree -from dvc.remote.local import LocalRemote, LocalRemoteTree -from dvc.remote.oss import OSSRemoteTree -from dvc.remote.s3 import S3RemoteTree -from dvc.remote.ssh import SSHRemote, SSHRemoteTree - -TREES = [ - AzureRemoteTree, - GDriveRemoteTree, - GSRemoteTree, - HDFSRemoteTree, - HTTPRemoteTree, - HTTPSRemoteTree, - S3RemoteTree, - SSHRemoteTree, - OSSRemoteTree, - 
# NOTE: LocalRemoteTree is the default -] - - -def _get_tree(remote_conf): - for tree_cls in TREES: - if tree_cls.supported(remote_conf): - return tree_cls - return LocalRemoteTree - - -def _get_conf(repo, **kwargs): - name = kwargs.get("name") - if name: - remote_conf = repo.config["remote"][name.lower()] - else: - remote_conf = kwargs - return _resolve_remote_refs(repo.config, remote_conf) - - -def _resolve_remote_refs(config, remote_conf): - # Support for cross referenced remotes. - # This will merge the settings, shadowing base ref with remote_conf. - # For example, having: - # - # dvc remote add server ssh://localhost - # dvc remote modify server user root - # dvc remote modify server ask_password true - # - # dvc remote add images remote://server/tmp/pictures - # dvc remote modify images user alice - # dvc remote modify images ask_password false - # dvc remote modify images password asdf1234 - # - # Results on a config dictionary like: - # - # { - # "url": "ssh://localhost/tmp/pictures", - # "user": "alice", - # "password": "asdf1234", - # "ask_password": False, - # } - parsed = urlparse(remote_conf["url"]) - if parsed.scheme != "remote": - return remote_conf - - base = config["remote"][parsed.netloc] - url = posixpath.join(base["url"], parsed.path.lstrip("/")) - return {**base, **remote_conf, "url": url} - - -def get_cloud_tree(repo, **kwargs): - remote_conf = _get_conf(repo, **kwargs) - return _get_tree(remote_conf)(repo, remote_conf) +from ..tree import get_cloud_tree +from .base import Remote +from .local import LocalRemote +from .ssh import SSHRemote def get_remote(repo, **kwargs): diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 91b116e700..acd9db2143 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -1,34 +1,19 @@ import hashlib -import itertools import json import logging -import tempfile -from concurrent.futures import ThreadPoolExecutor, as_completed from copy import copy -from functools import partial, wraps -from multiprocessing import cpu_count -from operator import itemgetter -from urllib.parse import urlparse +from functools import wraps from shortuuid import uuid import dvc.prompt as prompt -from dvc.exceptions import ( - CheckoutError, - ConfirmRemoveError, - DvcException, - DvcIgnoreInCollectedDirError, - RemoteCacheRequiredError, -) -from dvc.ignore import DvcIgnore -from dvc.path_info import PathInfo, URLInfo, WindowsPathInfo +from dvc.exceptions import CheckoutError, ConfirmRemoveError, DvcException +from dvc.path_info import WindowsPathInfo from dvc.progress import Tqdm from dvc.remote.index import RemoteIndex, RemoteIndexNoop from dvc.remote.slow_link_detection import slow_link_guard -from dvc.state import StateNoop -from dvc.utils import tmp_fname -from dvc.utils.fs import makedirs, move -from dvc.utils.http import open_url + +from ..tree.base import RemoteActionNotImplemented logger = logging.getLogger(__name__) @@ -46,24 +31,6 @@ } -class RemoteCmdError(DvcException): - def __init__(self, remote, cmd, ret, err): - super().__init__( - "{remote} command '{cmd}' finished with non-zero return code" - " {ret}': {err}".format(remote=remote, cmd=cmd, ret=ret, err=err) - ) - - -class RemoteActionNotImplemented(DvcException): - def __init__(self, action, scheme): - m = f"{action} is not supported for {scheme} remotes" - super().__init__(m) - - -class RemoteMissingDepsError(DvcException): - pass - - class DirCacheError(DvcException): def __init__(self, hash_): super().__init__( @@ -82,636 +49,6 @@ def wrapper(obj, named_cache, remote, *args, 
**kwargs): return wrapper -class BaseRemoteTree: - scheme = "base" - REQUIRES = {} - PATH_CLS = URLInfo - JOBS = 4 * cpu_count() - - PARAM_RELPATH = "relpath" - CHECKSUM_DIR_SUFFIX = ".dir" - HASH_JOBS = max(1, min(4, cpu_count() // 2)) - DEFAULT_VERIFY = False - LIST_OBJECT_PAGE_SIZE = 1000 - TRAVERSE_WEIGHT_MULTIPLIER = 5 - TRAVERSE_PREFIX_LEN = 3 - TRAVERSE_THRESHOLD_SIZE = 500000 - CAN_TRAVERSE = True - - CACHE_MODE = None - SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} - PARAM_CHECKSUM = None - - state = StateNoop() - - def __init__(self, repo, config): - self.repo = repo - self.config = config - - self._check_requires(config) - - shared = config.get("shared") - self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared] - - self.hash_jobs = ( - config.get("hash_jobs") - or (self.repo and self.repo.config["core"].get("hash_jobs")) - or self.HASH_JOBS - ) - self.verify = config.get("verify", self.DEFAULT_VERIFY) - self.path_info = None - - @classmethod - def get_missing_deps(cls): - import importlib - - missing = [] - for package, module in cls.REQUIRES.items(): - try: - importlib.import_module(module) - except ImportError: - missing.append(package) - - return missing - - def _check_requires(self, config): - missing = self.get_missing_deps() - if not missing: - return - - url = config.get("url", f"{self.scheme}://") - msg = ( - "URL '{}' is supported but requires these missing " - "dependencies: {}. If you have installed dvc using pip, " - "choose one of these options to proceed: \n" - "\n" - " 1) Install specific missing dependencies:\n" - " pip install {}\n" - " 2) Install dvc package that includes those missing " - "dependencies: \n" - " pip install 'dvc[{}]'\n" - " 3) Install dvc package with all possible " - "dependencies included: \n" - " pip install 'dvc[all]'\n" - "\n" - "If you have installed dvc from a binary package and you " - "are still seeing this message, please report it to us " - "using https://github.com/iterative/dvc/issues. Thank you!" - ).format(url, missing, " ".join(missing), self.scheme) - raise RemoteMissingDepsError(msg) - - @classmethod - def supported(cls, config): - if isinstance(config, (str, bytes)): - url = config - else: - url = config["url"] - - # NOTE: silently skipping remote, calling code should handle that - parsed = urlparse(url) - return parsed.scheme == cls.scheme - - @property - def file_mode(self): - return self._file_mode - - @property - def dir_mode(self): - return self._dir_mode - - @property - def cache(self): - return getattr(self.repo.cache, self.scheme) - - def open(self, path_info, mode="r", encoding=None): - if hasattr(self, "_generate_download_url"): - func = self._generate_download_url # noqa,pylint:disable=no-member - get_url = partial(func, path_info) - return open_url(get_url, mode=mode, encoding=encoding) - - raise RemoteActionNotImplemented("open", self.scheme) - - def exists(self, path_info): - raise NotImplementedError - - # pylint: disable=unused-argument - - def isdir(self, path_info): - """Optional: Overwrite only if the remote has a way to distinguish - between a directory and a file. - """ - return False - - def isfile(self, path_info): - """Optional: Overwrite only if the remote has a way to distinguish - between a directory and a file. - """ - return True - - def iscopy(self, path_info): - """Check if this file is an independent copy.""" - return False # We can't be sure by default - - def walk_files(self, path_info, **kwargs): - """Return a generator with `PathInfo`s to all the files. 
- - Optional kwargs: - prefix (bool): If true `path_info` will be treated as a prefix - rather than directory path. - """ - raise NotImplementedError - - def is_empty(self, path_info): - return False - - def remove(self, path_info): - raise RemoteActionNotImplemented("remove", self.scheme) - - def makedirs(self, path_info): - """Optional: Implement only if the remote needs to create - directories before copying/linking/moving data - """ - - def move(self, from_info, to_info, mode=None): - assert mode is None - self.copy(from_info, to_info) - self.remove(from_info) - - def copy(self, from_info, to_info): - raise RemoteActionNotImplemented("copy", self.scheme) - - def copy_fobj(self, fobj, to_info): - raise RemoteActionNotImplemented("copy_fobj", self.scheme) - - def symlink(self, from_info, to_info): - raise RemoteActionNotImplemented("symlink", self.scheme) - - def hardlink(self, from_info, to_info): - raise RemoteActionNotImplemented("hardlink", self.scheme) - - def reflink(self, from_info, to_info): - raise RemoteActionNotImplemented("reflink", self.scheme) - - @staticmethod - def protect(path_info): - pass - - def is_protected(self, path_info): - return False - - # pylint: enable=unused-argument - - @staticmethod - def unprotect(path_info): - pass - - @classmethod - def is_dir_hash(cls, hash_): - if not hash_: - return False - return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX) - - def get_hash(self, path_info, tree=None, **kwargs): - assert isinstance(path_info, str) or path_info.scheme == self.scheme - - if not tree: - tree = self - - if not tree.exists(path_info): - return None - - if tree == self: - # pylint: disable=assignment-from-none - hash_ = self.state.get(path_info) - else: - hash_ = None - # If we have dir hash in state db, but dir cache file is lost, - # then we need to recollect the dir via .get_dir_hash() call below, - # see https://github.com/iterative/dvc/issues/2219 for context - if ( - hash_ - and self.is_dir_hash(hash_) - and not tree.exists(self.cache.hash_to_path_info(hash_)) - ): - hash_ = None - - if hash_: - return hash_ - - if tree.isdir(path_info): - hash_ = self.get_dir_hash(path_info, tree, **kwargs) - else: - hash_ = tree.get_file_hash(path_info) - - if hash_ and self.exists(path_info): - self.state.save(path_info, hash_) - - return hash_ - - def get_file_hash(self, path_info): - raise NotImplementedError - - def get_dir_hash(self, path_info, tree, **kwargs): - if not self.cache: - raise RemoteCacheRequiredError(path_info) - - dir_info = self._collect_dir(path_info, tree, **kwargs) - return self._save_dir_info(dir_info, path_info) - - def hash_to_path_info(self, hash_): - return self.path_info / hash_[0:2] / hash_[2:] - - def path_to_hash(self, path): - parts = self.PATH_CLS(path).parts[-2:] - - if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): - raise ValueError(f"Bad cache file path '{path}'") - - return "".join(parts) - - def save_info(self, path_info, tree=None, **kwargs): - return { - self.PARAM_CHECKSUM: self.get_hash(path_info, tree=tree, **kwargs) - } - - @staticmethod - def _calculate_hashes(file_infos, tree): - file_infos = list(file_infos) - with Tqdm( - total=len(file_infos), - unit="md5", - desc="Computing file/dir hashes (only done once)", - ) as pbar: - worker = pbar.wrap_fn(tree.get_file_hash) - with ThreadPoolExecutor(max_workers=tree.hash_jobs) as executor: - tasks = executor.map(worker, file_infos) - hashes = dict(zip(file_infos, tasks)) - return hashes - - def _collect_dir(self, path_info, tree, **kwargs): - file_infos = set() - - 
for fname in tree.walk_files(path_info, **kwargs): - if DvcIgnore.DVCIGNORE_FILE == fname.name: - raise DvcIgnoreInCollectedDirError(fname.parent) - - file_infos.add(fname) - - hashes = {fi: self.state.get(fi) for fi in file_infos} - not_in_state = {fi for fi, hash_ in hashes.items() if hash_ is None} - - new_hashes = self._calculate_hashes(not_in_state, tree) - hashes.update(new_hashes) - - result = [ - { - self.PARAM_CHECKSUM: hashes[fi], - # NOTE: this is lossy transformation: - # "hey\there" -> "hey/there" - # "hey/there" -> "hey/there" - # The latter is fine filename on Windows, which - # will transform to dir/file on back transform. - # - # Yes, this is a BUG, as long as we permit "/" in - # filenames on Windows and "\" on Unix - self.PARAM_RELPATH: fi.relative_to(path_info).as_posix(), - } - for fi in file_infos - ] - - # Sorting the list by path to ensure reproducibility - return sorted(result, key=itemgetter(self.PARAM_RELPATH)) - - def _save_dir_info(self, dir_info, path_info): - hash_, tmp_info = self._get_dir_info_hash(dir_info) - new_info = self.cache.hash_to_path_info(hash_) - if self.cache.changed_cache_file(hash_): - self.cache.tree.makedirs(new_info.parent) - self.cache.tree.move( - tmp_info, new_info, mode=self.cache.CACHE_MODE - ) - - if self.exists(path_info): - self.state.save(path_info, hash_) - self.state.save(new_info, hash_) - - return hash_ - - def _get_dir_info_hash(self, dir_info): - tmp = tempfile.NamedTemporaryFile(delete=False).name - with open(tmp, "w+") as fobj: - json.dump(dir_info, fobj, sort_keys=True) - - tree = self.cache.tree - from_info = PathInfo(tmp) - to_info = tree.path_info / tmp_fname("") - tree.upload(from_info, to_info, no_progress_bar=True) - - hash_ = tree.get_file_hash(to_info) + self.CHECKSUM_DIR_SUFFIX - return hash_, to_info - - def upload(self, from_info, to_info, name=None, no_progress_bar=False): - if not hasattr(self, "_upload"): - raise RemoteActionNotImplemented("upload", self.scheme) - - if to_info.scheme != self.scheme: - raise NotImplementedError - - if from_info.scheme != "local": - raise NotImplementedError - - logger.debug("Uploading '%s' to '%s'", from_info, to_info) - - name = name or from_info.name - - self._upload( # noqa, pylint: disable=no-member - from_info.fspath, - to_info, - name=name, - no_progress_bar=no_progress_bar, - ) - - def download( - self, - from_info, - to_info, - name=None, - no_progress_bar=False, - file_mode=None, - dir_mode=None, - ): - if not hasattr(self, "_download"): - raise RemoteActionNotImplemented("download", self.scheme) - - if from_info.scheme != self.scheme: - raise NotImplementedError - - if to_info.scheme == self.scheme != "local": - self.copy(from_info, to_info) - return 0 - - if to_info.scheme != "local": - raise NotImplementedError - - if self.isdir(from_info): - return self._download_dir( - from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ) - return self._download_file( - from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ) - - def _download_dir( - self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ): - from_infos = list(self.walk_files(from_info)) - to_infos = ( - to_info / info.relative_to(from_info) for info in from_infos - ) - - with Tqdm( - total=len(from_infos), - desc="Downloading directory", - unit="Files", - disable=no_progress_bar, - ) as pbar: - download_files = pbar.wrap_fn( - partial( - self._download_file, - name=name, - no_progress_bar=True, - file_mode=file_mode, - dir_mode=dir_mode, - ) - ) - with 
ThreadPoolExecutor(max_workers=self.JOBS) as executor: - futures = [ - executor.submit(download_files, from_info, to_info) - for from_info, to_info in zip(from_infos, to_infos) - ] - - # NOTE: unlike pulling/fetching cache, where we need to - # download everything we can, not raising an error here might - # turn very ugly, as the user might think that he has - # downloaded a complete directory, while having a partial one, - # which might cause unexpected results in his pipeline. - for future in as_completed(futures): - # NOTE: executor won't let us raise until all futures that - # it has are finished, so we need to cancel them ourselves - # before re-raising. - exc = future.exception() - if exc: - for entry in futures: - entry.cancel() - raise exc - - def _download_file( - self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode - ): - makedirs(to_info.parent, exist_ok=True, mode=dir_mode) - - logger.debug("Downloading '%s' to '%s'", from_info, to_info) - name = name or to_info.name - - tmp_file = tmp_fname(to_info) - - self._download( # noqa, pylint: disable=no-member - from_info, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - - move(tmp_file, to_info, mode=file_mode) - - def list_paths(self, prefix=None, progress_callback=None): - if prefix: - if len(prefix) > 2: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info / prefix[:2] - prefix = True - else: - path_info = self.path_info - prefix = False - if progress_callback: - for file_info in self.walk_files(path_info, prefix=prefix): - progress_callback() - yield file_info.path - else: - yield from self.walk_files(path_info, prefix=prefix) - - def list_hashes(self, prefix=None, progress_callback=None): - """Iterate over hashes in this tree. - - If `prefix` is specified, only hashes which begin with `prefix` - will be returned. - """ - for path in self.list_paths(prefix, progress_callback): - try: - yield self.path_to_hash(path) - except ValueError: - logger.debug( - "'%s' doesn't look like a cache file, skipping", path - ) - - def all(self, jobs=None, name=None): - """Iterate over all hashes in this tree. - - Hashes will be fetched in parallel threads according to prefix - (except for small remotes) and a progress bar will be displayed. - """ - logger.debug( - "Fetching all hashes from '{}'".format( - name if name else "remote cache" - ) - ) - - if not self.CAN_TRAVERSE: - return self.list_hashes() - - remote_size, remote_hashes = self.estimate_remote_size(name=name) - return self.list_hashes_traverse( - remote_size, remote_hashes, jobs, name - ) - - def _hashes_with_limit(self, limit, prefix=None, progress_callback=None): - count = 0 - for hash_ in self.list_hashes(prefix, progress_callback): - yield hash_ - count += 1 - if count > limit: - logger.debug( - "`list_hashes()` returned max '{}' hashes, " - "skipping remaining results".format(limit) - ) - return - - def _max_estimation_size(self, hashes): - # Max remote size allowed for us to use traverse method - return max( - self.TRAVERSE_THRESHOLD_SIZE, - len(hashes) - / self.TRAVERSE_WEIGHT_MULTIPLIER - * self.LIST_OBJECT_PAGE_SIZE, - ) - - def estimate_remote_size(self, hashes=None, name=None): - """Estimate tree size based on number of entries beginning with - "00..." prefix. 
- """ - prefix = "0" * self.TRAVERSE_PREFIX_LEN - total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) - if hashes: - max_hashes = self._max_estimation_size(hashes) - else: - max_hashes = None - - with Tqdm( - desc="Estimating size of " - + (f"cache in '{name}'" if name else "remote cache"), - unit="file", - ) as pbar: - - def update(n=1): - pbar.update(n * total_prefixes) - - if max_hashes: - hashes = self._hashes_with_limit( - max_hashes / total_prefixes, prefix, update - ) - else: - hashes = self.list_hashes(prefix, update) - - remote_hashes = set(hashes) - if remote_hashes: - remote_size = total_prefixes * len(remote_hashes) - else: - remote_size = total_prefixes - logger.debug(f"Estimated remote size: {remote_size} files") - return remote_size, remote_hashes - - def list_hashes_traverse( - self, remote_size, remote_hashes, jobs=None, name=None - ): - """Iterate over all hashes found in this tree. - Hashes are fetched in parallel according to prefix, except in - cases where the remote size is very small. - - All hashes from the remote (including any from the size - estimation step passed via the `remote_hashes` argument) will be - returned. - - NOTE: For large remotes the list of hashes will be very - big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) - and we don't really need all of it at the same time, so it makes - sense to use a generator to gradually iterate over it, without - keeping all of it in memory. - """ - num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE - if num_pages < 256 / self.JOBS: - # Fetching prefixes in parallel requires at least 255 more - # requests, for small enough remotes it will be faster to fetch - # entire cache without splitting it into prefixes. - # - # NOTE: this ends up re-fetching hashes that were already - # fetched during remote size estimation - traverse_prefixes = [None] - initial = 0 - else: - yield from remote_hashes - initial = len(remote_hashes) - traverse_prefixes = [f"{i:02x}" for i in range(1, 256)] - if self.TRAVERSE_PREFIX_LEN > 2: - traverse_prefixes += [ - "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) - for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) - ] - with Tqdm( - desc="Querying " - + (f"cache in '{name}'" if name else "remote cache"), - total=remote_size, - initial=initial, - unit="file", - ) as pbar: - - def list_with_update(prefix): - return list( - self.list_hashes( - prefix=prefix, progress_callback=pbar.update - ) - ) - - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - in_remote = executor.map(list_with_update, traverse_prefixes,) - yield from itertools.chain.from_iterable(in_remote) - - def list_hashes_exists(self, hashes, jobs=None, name=None): - """Return list of the specified hashes which exist in this tree. - Hashes will be queried individually. - """ - logger.debug( - "Querying {} hashes via object_exists".format(len(hashes)) - ) - with Tqdm( - desc="Querying " - + ("cache in " + name if name else "remote cache"), - total=len(hashes), - unit="file", - ) as pbar: - - def exists_with_progress(path_info): - ret = self.exists(path_info) - pbar.update_msg(str(path_info)) - return ret - - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - path_infos = map(self.hash_to_path_info, hashes) - in_remote = executor.map(exists_with_progress, path_infos) - ret = list(itertools.compress(hashes, in_remote)) - return ret - - def _remove_unpacked_dir(self, hash_): - pass - - class Remote: """Cloud remote class. 
@@ -735,41 +72,15 @@ def __init__(self, tree): else: self.index = RemoteIndexNoop() - @property - def path_info(self): - return self.tree.path_info - def __repr__(self): return "{class_name}: '{path_info}'".format( class_name=type(self).__name__, - path_info=self.path_info or "No path", + path_info=self.tree.path_info or "No path", ) @property def cache(self): - return getattr(self.repo.cache, self.scheme) - - @property - def scheme(self): - return self.tree.scheme - - def is_dir_hash(self, hash_): - return self.tree.is_dir_hash(hash_) - - def get_hash(self, path_info, **kwargs): - return self.tree.get_hash(path_info, **kwargs) - - def hash_to_path_info(self, hash_): - return self.tree.hash_to_path_info(hash_) - - def path_to_hash(self, path): - return self.tree.path_to_hash(path) - - def save_info(self, path_info, **kwargs): - return self.tree.save_info(path_info, **kwargs) - - def open(self, *args, **kwargs): - return self.tree.open(*args, **kwargs) + return getattr(self.repo.cache, self.tree.scheme) def hashes_exist(self, hashes, jobs=None, name=None): """Check if the given hashes are stored in the remote. @@ -893,7 +204,7 @@ class CloudCache: """Cloud cache class.""" DEFAULT_CACHE_TYPES = ["copy"] - CACHE_MODE = BaseRemoteTree.CACHE_MODE + CACHE_MODE = None def __init__(self, tree): self.tree = tree @@ -1123,7 +434,7 @@ def _cache_is_copy(self, path_info): return True workspace_file = path_info.with_name("." + uuid()) - test_cache_file = self.path_info / ".cache_type_test_file" + test_cache_file = self.tree.path_info / ".cache_type_test_file" if not self.tree.exists(test_cache_file): with self.tree.open(test_cache_file, "wb") as fobj: fobj.write(bytes(1)) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index c691a78ab1..b7ca83930c 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -1,14 +1,12 @@ import errno import logging import os -import stat from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial, wraps from funcy import cached_property, concat -from shortuuid import uuid -from dvc.exceptions import DownloadError, DvcException, UploadError +from dvc.exceptions import DownloadError, UploadError from dvc.path_info import PathInfo from dvc.progress import Tqdm from dvc.remote.base import ( @@ -16,309 +14,15 @@ STATUS_MAP, STATUS_MISSING, STATUS_NEW, - BaseRemoteTree, CloudCache, Remote, index_locked, ) from dvc.remote.index import RemoteIndexNoop -from dvc.scheme import Schemes -from dvc.scm.tree import WorkingTree, is_working_tree -from dvc.system import System -from dvc.utils import file_md5, relpath, tmp_fname -from dvc.utils.fs import ( - copy_fobj_to_file, - copyfile, - makedirs, - move, - remove, - walk_files, -) - -logger = logging.getLogger(__name__) - - -class LocalRemoteTree(BaseRemoteTree): - scheme = Schemes.LOCAL - PATH_CLS = PathInfo - PARAM_CHECKSUM = "md5" - PARAM_PATH = "path" - TRAVERSE_PREFIX_LEN = 2 - UNPACKED_DIR_SUFFIX = ".unpacked" - - CACHE_MODE = 0o444 - SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} - - def __init__(self, repo, config): - super().__init__(repo, config) - url = config.get("url") - self.path_info = self.PATH_CLS(url) if url else None - - @property - def state(self): - from dvc.state import StateNoop - - return self.repo.state if self.repo else StateNoop() - - @cached_property - def work_tree(self): - # When using repo.brancher, repo.tree may change to/from WorkingTree to - # GitTree arbitarily. 
When repo.tree is GitTree, local cache needs to - # use its own WorkingTree instance. - if self.repo: - return WorkingTree(self.repo.root_dir) - return None - - @staticmethod - def open(path_info, mode="r", encoding=None): - return open(path_info, mode=mode, encoding=encoding) - - def exists(self, path_info): - assert isinstance(path_info, str) or path_info.scheme == "local" - if not self.repo: - return os.path.exists(path_info) - return self.work_tree.exists(path_info) - - def isfile(self, path_info): - if not self.repo: - return os.path.isfile(path_info) - return self.work_tree.isfile(path_info) - - def isdir(self, path_info): - if not self.repo: - return os.path.isdir(path_info) - return self.work_tree.isdir(path_info) - - def iscopy(self, path_info): - return not ( - System.is_symlink(path_info) or System.is_hardlink(path_info) - ) - - def walk_files(self, path_info, **kwargs): - for fname in self.work_tree.walk_files(path_info): - yield PathInfo(fname) - - def is_empty(self, path_info): - path = path_info.fspath - - if self.isfile(path_info) and os.path.getsize(path) == 0: - return True - - if self.isdir(path_info) and len(os.listdir(path)) == 0: - return True - - return False - - def remove(self, path_info): - if isinstance(path_info, PathInfo): - if path_info.scheme != "local": - raise NotImplementedError - path = path_info.fspath - else: - path = path_info - - if self.exists(path): - remove(path) - - def makedirs(self, path_info): - makedirs(path_info, exist_ok=True, mode=self.dir_mode) - - def move(self, from_info, to_info, mode=None): - if from_info.scheme != "local" or to_info.scheme != "local": - raise NotImplementedError - - self.makedirs(to_info.parent) - - if mode is None: - if self.isfile(from_info): - mode = self.file_mode - else: - mode = self.dir_mode - - move(from_info, to_info, mode=mode) - - def copy(self, from_info, to_info): - tmp_info = to_info.parent / tmp_fname(to_info.name) - try: - System.copy(from_info, tmp_info) - os.chmod(tmp_info, self.file_mode) - os.rename(tmp_info, to_info) - except Exception: - self.remove(tmp_info) - raise - - def copy_fobj(self, fobj, to_info): - self.makedirs(to_info.parent) - tmp_info = to_info.parent / tmp_fname(to_info.name) - try: - copy_fobj_to_file(fobj, tmp_info) - os.chmod(tmp_info, self.file_mode) - os.rename(tmp_info, to_info) - except Exception: - self.remove(tmp_info) - raise - - @staticmethod - def symlink(from_info, to_info): - System.symlink(from_info, to_info) - - @staticmethod - def is_symlink(path_info): - return System.is_symlink(path_info) - - def hardlink(self, from_info, to_info): - # If there are a lot of empty files (which happens a lot in datasets), - # and the cache type is `hardlink`, we might reach link limits and - # will get something like: `too many links error` - # - # This is because all those empty files will have the same hash - # (i.e. 68b329da9893e34099c7d8ad5cb9c940), therefore, they will be - # linked to the same file in the cache. - # - # From https://en.wikipedia.org/wiki/Hard_link - # * ext4 limits the number of hard links on a file to 65,000 - # * Windows with NTFS has a limit of 1024 hard links on a file - # - # That's why we simply create an empty file rather than a link. 
- if self.getsize(from_info) == 0: - self.open(to_info, "w").close() - - logger.debug( - "Created empty file: {src} -> {dest}".format( - src=str(from_info), dest=str(to_info) - ) - ) - return - System.hardlink(from_info, to_info) +from ..tree.local import LocalRemoteTree - @staticmethod - def is_hardlink(path_info): - return System.is_hardlink(path_info) - - def reflink(self, from_info, to_info): - tmp_info = to_info.parent / tmp_fname(to_info.name) - System.reflink(from_info, tmp_info) - # NOTE: reflink has its own separate inode, so you can set permissions - # that are different from the source. - os.chmod(tmp_info, self.file_mode) - os.rename(tmp_info, to_info) - - def _unprotect_file(self, path): - if System.is_symlink(path) or System.is_hardlink(path): - logger.debug(f"Unprotecting '{path}'") - tmp = os.path.join(os.path.dirname(path), "." + uuid()) - - # The operations order is important here - if some application - # would access the file during the process of copyfile then it - # would get only the part of file. So, at first, the file should be - # copied with the temporary name, and then original file should be - # replaced by new. - copyfile(path, tmp, name="Unprotecting '{}'".format(relpath(path))) - remove(path) - os.rename(tmp, path) - - else: - logger.debug( - "Skipping copying for '{}', since it is not " - "a symlink or a hardlink.".format(path) - ) - - os.chmod(path, self.file_mode) - - def _unprotect_dir(self, path): - assert is_working_tree(self.repo.tree) - - for fname in self.repo.tree.walk_files(path): - self._unprotect_file(fname) - - def unprotect(self, path_info): - path = path_info.fspath - if not os.path.exists(path): - raise DvcException(f"can't unprotect non-existing data '{path}'") - - if os.path.isdir(path): - self._unprotect_dir(path) - else: - self._unprotect_file(path) - - def protect(self, path_info): - path = os.fspath(path_info) - mode = self.CACHE_MODE - - try: - os.chmod(path, mode) - except OSError as exc: - # There is nothing we need to do in case of a read-only file system - if exc.errno == errno.EROFS: - return - - # In shared cache scenario, we might not own the cache file, so we - # need to check if cache file is already protected. 
- if exc.errno not in [errno.EPERM, errno.EACCES]: - raise - - actual = stat.S_IMODE(os.stat(path).st_mode) - if actual != mode: - raise - - def is_protected(self, path_info): - try: - mode = os.stat(path_info).st_mode - except FileNotFoundError: - return False - - return stat.S_IMODE(mode) == self.CACHE_MODE - - def get_file_hash(self, path_info): - return file_md5(path_info)[0] - - @staticmethod - def getsize(path_info): - return os.path.getsize(path_info) - - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - makedirs(to_info.parent, exist_ok=True) - - tmp_file = tmp_fname(to_info) - copyfile( - from_file, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - - self.protect(tmp_file) - os.rename(tmp_file, to_info) - - @staticmethod - def _download( - from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - copyfile( - from_info, to_file, no_progress_bar=no_progress_bar, name=name - ) - - def list_paths(self, prefix=None, progress_callback=None): - assert self.path_info is not None - if prefix: - path_info = self.path_info / prefix[:2] - if not self.exists(path_info): - return - else: - path_info = self.path_info - # NOTE: use utils.fs walk_files since tree.walk_files will not follow - # symlinks - if progress_callback: - for path in walk_files(path_info): - progress_callback() - yield path - else: - yield from walk_files(path_info) - - def _remove_unpacked_dir(self, hash_): - info = self.hash_to_path_info(hash_) - path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) - self.remove(path_info) +logger = logging.getLogger(__name__) def _log_exceptions(func, operation): @@ -440,7 +144,9 @@ def _status( {dir_hash: set(file_hash, ...)} which can be used to map a .dir file to its file contents. 
""" - logger.debug(f"Preparing to collect status from {remote.path_info}") + logger.debug( + f"Preparing to collect status from {remote.tree.path_info}" + ) md5s = set(named_cache.scheme_keys(self.scheme)) logger.debug("Collecting information from local cache...") @@ -466,7 +172,7 @@ def _status( if md5s: remote_exists.update( remote.hashes_exist( - md5s, jobs=jobs, name=str(remote.path_info) + md5s, jobs=jobs, name=str(remote.tree.path_info) ) ) return self._make_status( @@ -558,7 +264,7 @@ def _get_plans(self, download, remote, status_info, status): ): if info["status"] == status: cache.append(self.hash_to_path_info(md5)) - path_infos.append(remote.hash_to_path_info(md5)) + path_infos.append(remote.tree.hash_to_path_info(md5)) names.append(info["name"]) hashes.append(md5) @@ -582,7 +288,7 @@ def _process( logger.debug( "Preparing to {} '{}'".format( "download data from" if download else "upload data to", - remote.path_info, + remote.tree.path_info, ) ) diff --git a/dvc/remote/ssh.py b/dvc/remote/ssh.py new file mode 100644 index 0000000000..84a261acd6 --- /dev/null +++ b/dvc/remote/ssh.py @@ -0,0 +1,73 @@ +import errno +import itertools +import logging +from concurrent.futures import ThreadPoolExecutor + +from dvc.progress import Tqdm +from dvc.utils import to_chunks + +from .base import Remote + +logger = logging.getLogger(__name__) + + +class SSHRemote(Remote): + def batch_exists(self, path_infos, callback): + def _exists(chunk_and_channel): + chunk, channel = chunk_and_channel + ret = [] + for path in chunk: + try: + channel.stat(path) + ret.append(True) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + ret.append(False) + callback(path) + return ret + + with self.tree.ssh(path_infos[0]) as ssh: + channels = ssh.open_max_sftp_channels() + max_workers = len(channels) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + paths = [path_info.path for path_info in path_infos] + chunks = to_chunks(paths, num_chunks=max_workers) + chunks_and_channels = zip(chunks, channels) + outcome = executor.map(_exists, chunks_and_channels) + results = list(itertools.chain.from_iterable(outcome)) + + return results + + def hashes_exist(self, hashes, jobs=None, name=None): + """This is older implementation used in remote/base.py + We are reusing it in RemoteSSH, because SSH's batch_exists proved to be + faster than current approach (relying on exists(path_info)) applied in + remote/base. 
+ """ + if not self.tree.CAN_TRAVERSE: + return list(set(hashes) & set(self.tree.all())) + + # possibly prompt for credentials before "Querying" progress output + self.tree.ensure_credentials() + + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=len(hashes), + unit="file", + ) as pbar: + + def exists_with_progress(chunks): + return self.batch_exists(chunks, callback=pbar.update_msg) + + with ThreadPoolExecutor( + max_workers=jobs or self.tree.JOBS + ) as executor: + path_infos = [self.tree.hash_to_path_info(x) for x in hashes] + chunks = to_chunks(path_infos, num_chunks=self.tree.JOBS) + results = executor.map(exists_with_progress, chunks) + in_remote = itertools.chain.from_iterable(results) + ret = list(itertools.compress(hashes, in_remote)) + return ret diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index bc6e7fc9dd..5b78e2fef3 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -85,8 +85,8 @@ def open( else: checksum = out.checksum try: - remote_info = remote_obj.hash_to_path_info(checksum) - return remote_obj.open( + remote_info = remote_obj.tree.hash_to_path_info(checksum) + return remote_obj.tree.open( remote_info, mode=mode, encoding=encoding ) except RemoteActionNotImplemented: diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py index ab7f61f7f2..14b22a73df 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -8,8 +8,8 @@ from dvc.utils.fs import path_isin from ..dependency import ParamsDependency -from ..remote.local import LocalRemoteTree -from ..remote.s3 import S3RemoteTree +from ..tree.local import LocalRemoteTree +from ..tree.s3 import S3RemoteTree from ..utils import dict_md5, format_link, relpath from .exceptions import ( MissingDataSource, diff --git a/dvc/tree/__init__.py b/dvc/tree/__init__.py new file mode 100644 index 0000000000..c41d076475 --- /dev/null +++ b/dvc/tree/__init__.py @@ -0,0 +1,78 @@ +import posixpath +from urllib.parse import urlparse + +from .azure import AzureRemoteTree +from .gdrive import GDriveRemoteTree +from .gs import GSRemoteTree +from .hdfs import HDFSRemoteTree +from .http import HTTPRemoteTree +from .https import HTTPSRemoteTree +from .local import LocalRemoteTree +from .oss import OSSRemoteTree +from .s3 import S3RemoteTree +from .ssh import SSHRemoteTree + +TREES = [ + AzureRemoteTree, + GDriveRemoteTree, + GSRemoteTree, + HDFSRemoteTree, + HTTPRemoteTree, + HTTPSRemoteTree, + S3RemoteTree, + SSHRemoteTree, + OSSRemoteTree, + # NOTE: LocalRemoteTree is the default +] + + +def _get_tree(remote_conf): + for tree_cls in TREES: + if tree_cls.supported(remote_conf): + return tree_cls + return LocalRemoteTree + + +def _get_conf(repo, **kwargs): + name = kwargs.get("name") + if name: + remote_conf = repo.config["remote"][name.lower()] + else: + remote_conf = kwargs + return _resolve_remote_refs(repo.config, remote_conf) + + +def _resolve_remote_refs(config, remote_conf): + # Support for cross referenced remotes. + # This will merge the settings, shadowing base ref with remote_conf. 
+ # For example, having: + # + # dvc remote add server ssh://localhost + # dvc remote modify server user root + # dvc remote modify server ask_password true + # + # dvc remote add images remote://server/tmp/pictures + # dvc remote modify images user alice + # dvc remote modify images ask_password false + # dvc remote modify images password asdf1234 + # + # Results on a config dictionary like: + # + # { + # "url": "ssh://localhost/tmp/pictures", + # "user": "alice", + # "password": "asdf1234", + # "ask_password": False, + # } + parsed = urlparse(remote_conf["url"]) + if parsed.scheme != "remote": + return remote_conf + + base = config["remote"][parsed.netloc] + url = posixpath.join(base["url"], parsed.path.lstrip("/")) + return {**base, **remote_conf, "url": url} + + +def get_cloud_tree(repo, **kwargs): + remote_conf = _get_conf(repo, **kwargs) + return _get_tree(remote_conf)(repo, remote_conf) diff --git a/dvc/remote/azure.py b/dvc/tree/azure.py similarity index 99% rename from dvc/remote/azure.py rename to dvc/tree/azure.py index 77b831fc19..03c18617d9 100644 --- a/dvc/remote/azure.py +++ b/dvc/tree/azure.py @@ -7,9 +7,10 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/tree/base.py b/dvc/tree/base.py new file mode 100644 index 0000000000..00a982aea7 --- /dev/null +++ b/dvc/tree/base.py @@ -0,0 +1,703 @@ +import itertools +import json +import logging +import tempfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial, wraps +from multiprocessing import cpu_count +from operator import itemgetter +from urllib.parse import urlparse + +from dvc.exceptions import ( + DvcException, + DvcIgnoreInCollectedDirError, + RemoteCacheRequiredError, +) +from dvc.ignore import DvcIgnore +from dvc.path_info import PathInfo, URLInfo +from dvc.progress import Tqdm +from dvc.state import StateNoop +from dvc.utils import tmp_fname +from dvc.utils.fs import makedirs, move +from dvc.utils.http import open_url + +logger = logging.getLogger(__name__) + +STATUS_OK = 1 +STATUS_MISSING = 2 +STATUS_NEW = 3 +STATUS_DELETED = 4 + +STATUS_MAP = { + # (local_exists, remote_exists) + (True, True): STATUS_OK, + (False, False): STATUS_MISSING, + (True, False): STATUS_NEW, + (False, True): STATUS_DELETED, +} + + +class RemoteCmdError(DvcException): + def __init__(self, remote, cmd, ret, err): + super().__init__( + "{remote} command '{cmd}' finished with non-zero return code" + " {ret}': {err}".format(remote=remote, cmd=cmd, ret=ret, err=err) + ) + + +class RemoteActionNotImplemented(DvcException): + def __init__(self, action, scheme): + m = f"{action} is not supported for {scheme} remotes" + super().__init__(m) + + +class RemoteMissingDepsError(DvcException): + pass + + +class DirCacheError(DvcException): + def __init__(self, hash_): + super().__init__( + f"Failed to load dir cache for hash value: '{hash_}'." 
+ ) + + +def index_locked(f): + @wraps(f) + def wrapper(obj, named_cache, remote, *args, **kwargs): + if hasattr(remote, "index"): + with remote.index: + return f(obj, named_cache, remote, *args, **kwargs) + return f(obj, named_cache, remote, *args, **kwargs) + + return wrapper + + +class BaseRemoteTree: + scheme = "base" + REQUIRES = {} + PATH_CLS = URLInfo + JOBS = 4 * cpu_count() + + PARAM_RELPATH = "relpath" + CHECKSUM_DIR_SUFFIX = ".dir" + HASH_JOBS = max(1, min(4, cpu_count() // 2)) + DEFAULT_VERIFY = False + LIST_OBJECT_PAGE_SIZE = 1000 + TRAVERSE_WEIGHT_MULTIPLIER = 5 + TRAVERSE_PREFIX_LEN = 3 + TRAVERSE_THRESHOLD_SIZE = 500000 + CAN_TRAVERSE = True + + CACHE_MODE = None + SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + PARAM_CHECKSUM = None + + state = StateNoop() + + def __init__(self, repo, config): + self.repo = repo + self.config = config + + self._check_requires(config) + + shared = config.get("shared") + self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared] + + self.hash_jobs = ( + config.get("hash_jobs") + or (self.repo and self.repo.config["core"].get("hash_jobs")) + or self.HASH_JOBS + ) + self.verify = config.get("verify", self.DEFAULT_VERIFY) + self.path_info = None + + @classmethod + def get_missing_deps(cls): + import importlib + + missing = [] + for package, module in cls.REQUIRES.items(): + try: + importlib.import_module(module) + except ImportError: + missing.append(package) + + return missing + + def _check_requires(self, config): + missing = self.get_missing_deps() + if not missing: + return + + url = config.get("url", f"{self.scheme}://") + msg = ( + "URL '{}' is supported but requires these missing " + "dependencies: {}. If you have installed dvc using pip, " + "choose one of these options to proceed: \n" + "\n" + " 1) Install specific missing dependencies:\n" + " pip install {}\n" + " 2) Install dvc package that includes those missing " + "dependencies: \n" + " pip install 'dvc[{}]'\n" + " 3) Install dvc package with all possible " + "dependencies included: \n" + " pip install 'dvc[all]'\n" + "\n" + "If you have installed dvc from a binary package and you " + "are still seeing this message, please report it to us " + "using https://github.com/iterative/dvc/issues. Thank you!" + ).format(url, missing, " ".join(missing), self.scheme) + raise RemoteMissingDepsError(msg) + + @classmethod + def supported(cls, config): + if isinstance(config, (str, bytes)): + url = config + else: + url = config["url"] + + # NOTE: silently skipping remote, calling code should handle that + parsed = urlparse(url) + return parsed.scheme == cls.scheme + + @property + def file_mode(self): + return self._file_mode + + @property + def dir_mode(self): + return self._dir_mode + + @property + def cache(self): + return getattr(self.repo.cache, self.scheme) + + def open(self, path_info, mode="r", encoding=None): + if hasattr(self, "_generate_download_url"): + func = self._generate_download_url # noqa,pylint:disable=no-member + get_url = partial(func, path_info) + return open_url(get_url, mode=mode, encoding=encoding) + + raise RemoteActionNotImplemented("open", self.scheme) + + def exists(self, path_info): + raise NotImplementedError + + # pylint: disable=unused-argument + + def isdir(self, path_info): + """Optional: Overwrite only if the remote has a way to distinguish + between a directory and a file. + """ + return False + + def isfile(self, path_info): + """Optional: Overwrite only if the remote has a way to distinguish + between a directory and a file. 
+ """ + return True + + def iscopy(self, path_info): + """Check if this file is an independent copy.""" + return False # We can't be sure by default + + def walk_files(self, path_info, **kwargs): + """Return a generator with `PathInfo`s to all the files. + + Optional kwargs: + prefix (bool): If true `path_info` will be treated as a prefix + rather than directory path. + """ + raise NotImplementedError + + def is_empty(self, path_info): + return False + + def remove(self, path_info): + raise RemoteActionNotImplemented("remove", self.scheme) + + def makedirs(self, path_info): + """Optional: Implement only if the remote needs to create + directories before copying/linking/moving data + """ + + def move(self, from_info, to_info, mode=None): + assert mode is None + self.copy(from_info, to_info) + self.remove(from_info) + + def copy(self, from_info, to_info): + raise RemoteActionNotImplemented("copy", self.scheme) + + def copy_fobj(self, fobj, to_info): + raise RemoteActionNotImplemented("copy_fobj", self.scheme) + + def symlink(self, from_info, to_info): + raise RemoteActionNotImplemented("symlink", self.scheme) + + def hardlink(self, from_info, to_info): + raise RemoteActionNotImplemented("hardlink", self.scheme) + + def reflink(self, from_info, to_info): + raise RemoteActionNotImplemented("reflink", self.scheme) + + @staticmethod + def protect(path_info): + pass + + def is_protected(self, path_info): + return False + + # pylint: enable=unused-argument + + @staticmethod + def unprotect(path_info): + pass + + @classmethod + def is_dir_hash(cls, hash_): + if not hash_: + return False + return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX) + + def get_hash(self, path_info, tree=None, **kwargs): + assert isinstance(path_info, str) or path_info.scheme == self.scheme + + if not tree: + tree = self + + if not tree.exists(path_info): + return None + + if tree == self: + # pylint: disable=assignment-from-none + hash_ = self.state.get(path_info) + else: + hash_ = None + # If we have dir hash in state db, but dir cache file is lost, + # then we need to recollect the dir via .get_dir_hash() call below, + # see https://github.com/iterative/dvc/issues/2219 for context + if ( + hash_ + and self.is_dir_hash(hash_) + and not tree.exists(self.cache.hash_to_path_info(hash_)) + ): + hash_ = None + + if hash_: + return hash_ + + if tree.isdir(path_info): + hash_ = self.get_dir_hash(path_info, tree, **kwargs) + else: + hash_ = tree.get_file_hash(path_info) + + if hash_ and self.exists(path_info): + self.state.save(path_info, hash_) + + return hash_ + + def get_file_hash(self, path_info): + raise NotImplementedError + + def get_dir_hash(self, path_info, tree, **kwargs): + if not self.cache: + raise RemoteCacheRequiredError(path_info) + + dir_info = self._collect_dir(path_info, tree, **kwargs) + return self._save_dir_info(dir_info, path_info) + + def hash_to_path_info(self, hash_): + return self.path_info / hash_[0:2] / hash_[2:] + + def path_to_hash(self, path): + parts = self.PATH_CLS(path).parts[-2:] + + if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): + raise ValueError(f"Bad cache file path '{path}'") + + return "".join(parts) + + def save_info(self, path_info, tree=None, **kwargs): + return { + self.PARAM_CHECKSUM: self.get_hash(path_info, tree=tree, **kwargs) + } + + @staticmethod + def _calculate_hashes(file_infos, tree): + file_infos = list(file_infos) + with Tqdm( + total=len(file_infos), + unit="md5", + desc="Computing file/dir hashes (only done once)", + ) as pbar: + worker = 
pbar.wrap_fn(tree.get_file_hash) + with ThreadPoolExecutor(max_workers=tree.hash_jobs) as executor: + tasks = executor.map(worker, file_infos) + hashes = dict(zip(file_infos, tasks)) + return hashes + + def _collect_dir(self, path_info, tree, **kwargs): + file_infos = set() + + for fname in tree.walk_files(path_info, **kwargs): + if DvcIgnore.DVCIGNORE_FILE == fname.name: + raise DvcIgnoreInCollectedDirError(fname.parent) + + file_infos.add(fname) + + hashes = {fi: self.state.get(fi) for fi in file_infos} + not_in_state = {fi for fi, hash_ in hashes.items() if hash_ is None} + + new_hashes = self._calculate_hashes(not_in_state, tree) + hashes.update(new_hashes) + + result = [ + { + self.PARAM_CHECKSUM: hashes[fi], + # NOTE: this is lossy transformation: + # "hey\there" -> "hey/there" + # "hey/there" -> "hey/there" + # The latter is fine filename on Windows, which + # will transform to dir/file on back transform. + # + # Yes, this is a BUG, as long as we permit "/" in + # filenames on Windows and "\" on Unix + self.PARAM_RELPATH: fi.relative_to(path_info).as_posix(), + } + for fi in file_infos + ] + + # Sorting the list by path to ensure reproducibility + return sorted(result, key=itemgetter(self.PARAM_RELPATH)) + + def _save_dir_info(self, dir_info, path_info): + hash_, tmp_info = self._get_dir_info_hash(dir_info) + new_info = self.cache.hash_to_path_info(hash_) + if self.cache.changed_cache_file(hash_): + self.cache.tree.makedirs(new_info.parent) + self.cache.tree.move( + tmp_info, new_info, mode=self.cache.CACHE_MODE + ) + + if self.exists(path_info): + self.state.save(path_info, hash_) + self.state.save(new_info, hash_) + + return hash_ + + def _get_dir_info_hash(self, dir_info): + tmp = tempfile.NamedTemporaryFile(delete=False).name + with open(tmp, "w+") as fobj: + json.dump(dir_info, fobj, sort_keys=True) + + tree = self.cache.tree + from_info = PathInfo(tmp) + to_info = tree.path_info / tmp_fname("") + tree.upload(from_info, to_info, no_progress_bar=True) + + hash_ = tree.get_file_hash(to_info) + self.CHECKSUM_DIR_SUFFIX + return hash_, to_info + + def upload(self, from_info, to_info, name=None, no_progress_bar=False): + if not hasattr(self, "_upload"): + raise RemoteActionNotImplemented("upload", self.scheme) + + if to_info.scheme != self.scheme: + raise NotImplementedError + + if from_info.scheme != "local": + raise NotImplementedError + + logger.debug("Uploading '%s' to '%s'", from_info, to_info) + + name = name or from_info.name + + self._upload( # noqa, pylint: disable=no-member + from_info.fspath, + to_info, + name=name, + no_progress_bar=no_progress_bar, + ) + + def download( + self, + from_info, + to_info, + name=None, + no_progress_bar=False, + file_mode=None, + dir_mode=None, + ): + if not hasattr(self, "_download"): + raise RemoteActionNotImplemented("download", self.scheme) + + if from_info.scheme != self.scheme: + raise NotImplementedError + + if to_info.scheme == self.scheme != "local": + self.copy(from_info, to_info) + return 0 + + if to_info.scheme != "local": + raise NotImplementedError + + if self.isdir(from_info): + return self._download_dir( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + return self._download_file( + from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ) + + def _download_dir( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + from_infos = list(self.walk_files(from_info)) + to_infos = ( + to_info / info.relative_to(from_info) for info in from_infos + ) + + with Tqdm( + 
total=len(from_infos), + desc="Downloading directory", + unit="Files", + disable=no_progress_bar, + ) as pbar: + download_files = pbar.wrap_fn( + partial( + self._download_file, + name=name, + no_progress_bar=True, + file_mode=file_mode, + dir_mode=dir_mode, + ) + ) + with ThreadPoolExecutor(max_workers=self.JOBS) as executor: + futures = [ + executor.submit(download_files, from_info, to_info) + for from_info, to_info in zip(from_infos, to_infos) + ] + + # NOTE: unlike pulling/fetching cache, where we need to + # download everything we can, not raising an error here might + # turn very ugly, as the user might think that he has + # downloaded a complete directory, while having a partial one, + # which might cause unexpected results in his pipeline. + for future in as_completed(futures): + # NOTE: executor won't let us raise until all futures that + # it has are finished, so we need to cancel them ourselves + # before re-raising. + exc = future.exception() + if exc: + for entry in futures: + entry.cancel() + raise exc + + def _download_file( + self, from_info, to_info, name, no_progress_bar, file_mode, dir_mode + ): + makedirs(to_info.parent, exist_ok=True, mode=dir_mode) + + logger.debug("Downloading '%s' to '%s'", from_info, to_info) + name = name or to_info.name + + tmp_file = tmp_fname(to_info) + + self._download( # noqa, pylint: disable=no-member + from_info, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + + move(tmp_file, to_info, mode=file_mode) + + def list_paths(self, prefix=None, progress_callback=None): + if prefix: + if len(prefix) > 2: + path_info = self.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.path_info / prefix[:2] + prefix = True + else: + path_info = self.path_info + prefix = False + if progress_callback: + for file_info in self.walk_files(path_info, prefix=prefix): + progress_callback() + yield file_info.path + else: + yield from self.walk_files(path_info, prefix=prefix) + + def list_hashes(self, prefix=None, progress_callback=None): + """Iterate over hashes in this tree. + + If `prefix` is specified, only hashes which begin with `prefix` + will be returned. + """ + for path in self.list_paths(prefix, progress_callback): + try: + yield self.path_to_hash(path) + except ValueError: + logger.debug( + "'%s' doesn't look like a cache file, skipping", path + ) + + def all(self, jobs=None, name=None): + """Iterate over all hashes in this tree. + + Hashes will be fetched in parallel threads according to prefix + (except for small remotes) and a progress bar will be displayed. 
+ """ + logger.debug( + "Fetching all hashes from '{}'".format( + name if name else "remote cache" + ) + ) + + if not self.CAN_TRAVERSE: + return self.list_hashes() + + remote_size, remote_hashes = self.estimate_remote_size(name=name) + return self.list_hashes_traverse( + remote_size, remote_hashes, jobs, name + ) + + def _hashes_with_limit(self, limit, prefix=None, progress_callback=None): + count = 0 + for hash_ in self.list_hashes(prefix, progress_callback): + yield hash_ + count += 1 + if count > limit: + logger.debug( + "`list_hashes()` returned max '{}' hashes, " + "skipping remaining results".format(limit) + ) + return + + def _max_estimation_size(self, hashes): + # Max remote size allowed for us to use traverse method + return max( + self.TRAVERSE_THRESHOLD_SIZE, + len(hashes) + / self.TRAVERSE_WEIGHT_MULTIPLIER + * self.LIST_OBJECT_PAGE_SIZE, + ) + + def estimate_remote_size(self, hashes=None, name=None): + """Estimate tree size based on number of entries beginning with + "00..." prefix. + """ + prefix = "0" * self.TRAVERSE_PREFIX_LEN + total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) + if hashes: + max_hashes = self._max_estimation_size(hashes) + else: + max_hashes = None + + with Tqdm( + desc="Estimating size of " + + (f"cache in '{name}'" if name else "remote cache"), + unit="file", + ) as pbar: + + def update(n=1): + pbar.update(n * total_prefixes) + + if max_hashes: + hashes = self._hashes_with_limit( + max_hashes / total_prefixes, prefix, update + ) + else: + hashes = self.list_hashes(prefix, update) + + remote_hashes = set(hashes) + if remote_hashes: + remote_size = total_prefixes * len(remote_hashes) + else: + remote_size = total_prefixes + logger.debug(f"Estimated remote size: {remote_size} files") + return remote_size, remote_hashes + + def list_hashes_traverse( + self, remote_size, remote_hashes, jobs=None, name=None + ): + """Iterate over all hashes found in this tree. + Hashes are fetched in parallel according to prefix, except in + cases where the remote size is very small. + + All hashes from the remote (including any from the size + estimation step passed via the `remote_hashes` argument) will be + returned. + + NOTE: For large remotes the list of hashes will be very + big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) + and we don't really need all of it at the same time, so it makes + sense to use a generator to gradually iterate over it, without + keeping all of it in memory. + """ + num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE + if num_pages < 256 / self.JOBS: + # Fetching prefixes in parallel requires at least 255 more + # requests, for small enough remotes it will be faster to fetch + # entire cache without splitting it into prefixes. 
+ # + # NOTE: this ends up re-fetching hashes that were already + # fetched during remote size estimation + traverse_prefixes = [None] + initial = 0 + else: + yield from remote_hashes + initial = len(remote_hashes) + traverse_prefixes = [f"{i:02x}" for i in range(1, 256)] + if self.TRAVERSE_PREFIX_LEN > 2: + traverse_prefixes += [ + "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) + for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) + ] + with Tqdm( + desc="Querying " + + (f"cache in '{name}'" if name else "remote cache"), + total=remote_size, + initial=initial, + unit="file", + ) as pbar: + + def list_with_update(prefix): + return list( + self.list_hashes( + prefix=prefix, progress_callback=pbar.update + ) + ) + + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + in_remote = executor.map(list_with_update, traverse_prefixes,) + yield from itertools.chain.from_iterable(in_remote) + + def list_hashes_exists(self, hashes, jobs=None, name=None): + """Return list of the specified hashes which exist in this tree. + Hashes will be queried individually. + """ + logger.debug( + "Querying {} hashes via object_exists".format(len(hashes)) + ) + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=len(hashes), + unit="file", + ) as pbar: + + def exists_with_progress(path_info): + ret = self.exists(path_info) + pbar.update_msg(str(path_info)) + return ret + + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + path_infos = map(self.hash_to_path_info, hashes) + in_remote = executor.map(exists_with_progress, path_infos) + ret = list(itertools.compress(hashes, in_remote)) + return ret + + def _remove_unpacked_dir(self, hash_): + pass diff --git a/dvc/remote/gdrive.py b/dvc/tree/gdrive.py similarity index 99% rename from dvc/remote/gdrive.py rename to dvc/tree/gdrive.py index b332b2d500..7c807269e3 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/tree/gdrive.py @@ -14,11 +14,12 @@ from dvc.exceptions import DvcException, FileMissingError from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes from dvc.utils import format_link, tmp_fname from dvc.utils.stream import IterStream +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) FOLDER_MIME_TYPE = "application/vnd.google-apps.folder" diff --git a/dvc/remote/gs.py b/dvc/tree/gs.py similarity index 99% rename from dvc/remote/gs.py rename to dvc/tree/gs.py index fc703231d2..7254ce4f28 100644 --- a/dvc/remote/gs.py +++ b/dvc/tree/gs.py @@ -9,9 +9,10 @@ from dvc.exceptions import DvcException from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/remote/hdfs.py b/dvc/tree/hdfs.py similarity index 100% rename from dvc/remote/hdfs.py rename to dvc/tree/hdfs.py diff --git a/dvc/remote/http.py b/dvc/tree/http.py similarity index 99% rename from dvc/remote/http.py rename to dvc/tree/http.py index d87c4932ea..8fddc9156e 100644 --- a/dvc/remote/http.py +++ b/dvc/tree/http.py @@ -8,9 +8,10 @@ from dvc.exceptions import DvcException, HTTPError from dvc.path_info import HTTPURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/remote/https.py b/dvc/tree/https.py 
similarity index 100% rename from dvc/remote/https.py rename to dvc/tree/https.py diff --git a/dvc/tree/local.py b/dvc/tree/local.py new file mode 100644 index 0000000000..cfa3e3b77c --- /dev/null +++ b/dvc/tree/local.py @@ -0,0 +1,309 @@ +import errno +import logging +import os +import stat + +from funcy import cached_property +from shortuuid import uuid + +from dvc.exceptions import DvcException +from dvc.path_info import PathInfo +from dvc.scheme import Schemes +from dvc.scm.tree import WorkingTree, is_working_tree +from dvc.system import System +from dvc.utils import file_md5, relpath, tmp_fname +from dvc.utils.fs import ( + copy_fobj_to_file, + copyfile, + makedirs, + move, + remove, + walk_files, +) + +from .base import BaseRemoteTree + +logger = logging.getLogger(__name__) + + +class LocalRemoteTree(BaseRemoteTree): + scheme = Schemes.LOCAL + PATH_CLS = PathInfo + PARAM_CHECKSUM = "md5" + PARAM_PATH = "path" + TRAVERSE_PREFIX_LEN = 2 + UNPACKED_DIR_SUFFIX = ".unpacked" + + CACHE_MODE = 0o444 + SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} + + def __init__(self, repo, config): + super().__init__(repo, config) + url = config.get("url") + self.path_info = self.PATH_CLS(url) if url else None + + @property + def state(self): + from dvc.state import StateNoop + + return self.repo.state if self.repo else StateNoop() + + @cached_property + def work_tree(self): + # When using repo.brancher, repo.tree may change to/from WorkingTree to + # GitTree arbitarily. When repo.tree is GitTree, local cache needs to + # use its own WorkingTree instance. + if self.repo: + return WorkingTree(self.repo.root_dir) + return None + + @staticmethod + def open(path_info, mode="r", encoding=None): + return open(path_info, mode=mode, encoding=encoding) + + def exists(self, path_info): + assert isinstance(path_info, str) or path_info.scheme == "local" + if not self.repo: + return os.path.exists(path_info) + return self.work_tree.exists(path_info) + + def isfile(self, path_info): + if not self.repo: + return os.path.isfile(path_info) + return self.work_tree.isfile(path_info) + + def isdir(self, path_info): + if not self.repo: + return os.path.isdir(path_info) + return self.work_tree.isdir(path_info) + + def iscopy(self, path_info): + return not ( + System.is_symlink(path_info) or System.is_hardlink(path_info) + ) + + def walk_files(self, path_info, **kwargs): + for fname in self.work_tree.walk_files(path_info): + yield PathInfo(fname) + + def is_empty(self, path_info): + path = path_info.fspath + + if self.isfile(path_info) and os.path.getsize(path) == 0: + return True + + if self.isdir(path_info) and len(os.listdir(path)) == 0: + return True + + return False + + def remove(self, path_info): + if isinstance(path_info, PathInfo): + if path_info.scheme != "local": + raise NotImplementedError + path = path_info.fspath + else: + path = path_info + + if self.exists(path): + remove(path) + + def makedirs(self, path_info): + makedirs(path_info, exist_ok=True, mode=self.dir_mode) + + def move(self, from_info, to_info, mode=None): + if from_info.scheme != "local" or to_info.scheme != "local": + raise NotImplementedError + + self.makedirs(to_info.parent) + + if mode is None: + if self.isfile(from_info): + mode = self.file_mode + else: + mode = self.dir_mode + + move(from_info, to_info, mode=mode) + + def copy(self, from_info, to_info): + tmp_info = to_info.parent / tmp_fname(to_info.name) + try: + System.copy(from_info, tmp_info) + os.chmod(tmp_info, self.file_mode) + os.rename(tmp_info, to_info) + except 
Exception: + self.remove(tmp_info) + raise + + def copy_fobj(self, fobj, to_info): + self.makedirs(to_info.parent) + tmp_info = to_info.parent / tmp_fname(to_info.name) + try: + copy_fobj_to_file(fobj, tmp_info) + os.chmod(tmp_info, self.file_mode) + os.rename(tmp_info, to_info) + except Exception: + self.remove(tmp_info) + raise + + @staticmethod + def symlink(from_info, to_info): + System.symlink(from_info, to_info) + + @staticmethod + def is_symlink(path_info): + return System.is_symlink(path_info) + + def hardlink(self, from_info, to_info): + # If there are a lot of empty files (which happens a lot in datasets), + # and the cache type is `hardlink`, we might reach link limits and + # will get something like: `too many links error` + # + # This is because all those empty files will have the same hash + # (i.e. 68b329da9893e34099c7d8ad5cb9c940), therefore, they will be + # linked to the same file in the cache. + # + # From https://en.wikipedia.org/wiki/Hard_link + # * ext4 limits the number of hard links on a file to 65,000 + # * Windows with NTFS has a limit of 1024 hard links on a file + # + # That's why we simply create an empty file rather than a link. + if self.getsize(from_info) == 0: + self.open(to_info, "w").close() + + logger.debug( + "Created empty file: {src} -> {dest}".format( + src=str(from_info), dest=str(to_info) + ) + ) + return + + System.hardlink(from_info, to_info) + + @staticmethod + def is_hardlink(path_info): + return System.is_hardlink(path_info) + + def reflink(self, from_info, to_info): + tmp_info = to_info.parent / tmp_fname(to_info.name) + System.reflink(from_info, tmp_info) + # NOTE: reflink has its own separate inode, so you can set permissions + # that are different from the source. + os.chmod(tmp_info, self.file_mode) + os.rename(tmp_info, to_info) + + def _unprotect_file(self, path): + if System.is_symlink(path) or System.is_hardlink(path): + logger.debug(f"Unprotecting '{path}'") + tmp = os.path.join(os.path.dirname(path), "." + uuid()) + + # The operations order is important here - if some application + # would access the file during the process of copyfile then it + # would get only the part of file. So, at first, the file should be + # copied with the temporary name, and then original file should be + # replaced by new. + copyfile(path, tmp, name="Unprotecting '{}'".format(relpath(path))) + remove(path) + os.rename(tmp, path) + + else: + logger.debug( + "Skipping copying for '{}', since it is not " + "a symlink or a hardlink.".format(path) + ) + + os.chmod(path, self.file_mode) + + def _unprotect_dir(self, path): + assert is_working_tree(self.repo.tree) + + for fname in self.repo.tree.walk_files(path): + self._unprotect_file(fname) + + def unprotect(self, path_info): + path = path_info.fspath + if not os.path.exists(path): + raise DvcException(f"can't unprotect non-existing data '{path}'") + + if os.path.isdir(path): + self._unprotect_dir(path) + else: + self._unprotect_file(path) + + def protect(self, path_info): + path = os.fspath(path_info) + mode = self.CACHE_MODE + + try: + os.chmod(path, mode) + except OSError as exc: + # There is nothing we need to do in case of a read-only file system + if exc.errno == errno.EROFS: + return + + # In shared cache scenario, we might not own the cache file, so we + # need to check if cache file is already protected. 
+ if exc.errno not in [errno.EPERM, errno.EACCES]: + raise + + actual = stat.S_IMODE(os.stat(path).st_mode) + if actual != mode: + raise + + def is_protected(self, path_info): + try: + mode = os.stat(path_info).st_mode + except FileNotFoundError: + return False + + return stat.S_IMODE(mode) == self.CACHE_MODE + + def get_file_hash(self, path_info): + return file_md5(path_info)[0] + + @staticmethod + def getsize(path_info): + return os.path.getsize(path_info) + + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + makedirs(to_info.parent, exist_ok=True) + + tmp_file = tmp_fname(to_info) + copyfile( + from_file, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + + self.protect(tmp_file) + os.rename(tmp_file, to_info) + + @staticmethod + def _download( + from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + copyfile( + from_info, to_file, no_progress_bar=no_progress_bar, name=name + ) + + def list_paths(self, prefix=None, progress_callback=None): + assert self.path_info is not None + if prefix: + path_info = self.path_info / prefix[:2] + if not self.exists(path_info): + return + else: + path_info = self.path_info + # NOTE: use utils.fs walk_files since tree.walk_files will not follow + # symlinks + if progress_callback: + for path in walk_files(path_info): + progress_callback() + yield path + else: + yield from walk_files(path_info) + + def _remove_unpacked_dir(self, hash_): + info = self.hash_to_path_info(hash_) + path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) + self.remove(path_info) diff --git a/dvc/remote/oss.py b/dvc/tree/oss.py similarity index 98% rename from dvc/remote/oss.py rename to dvc/tree/oss.py index 5471169afb..2a165e45c2 100644 --- a/dvc/remote/oss.py +++ b/dvc/tree/oss.py @@ -6,9 +6,10 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/remote/pool.py b/dvc/tree/pool.py similarity index 100% rename from dvc/remote/pool.py rename to dvc/tree/pool.py diff --git a/dvc/remote/s3.py b/dvc/tree/s3.py similarity index 99% rename from dvc/remote/s3.py rename to dvc/tree/s3.py index 2c01fd800e..17d58301f1 100644 --- a/dvc/remote/s3.py +++ b/dvc/tree/s3.py @@ -8,9 +8,10 @@ from dvc.exceptions import DvcException, ETagMismatchError from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes +from .base import BaseRemoteTree + logger = logging.getLogger(__name__) diff --git a/dvc/remote/ssh/__init__.py b/dvc/tree/ssh/__init__.py similarity index 77% rename from dvc/remote/ssh/__init__.py rename to dvc/tree/ssh/__init__.py index 456e85f990..8c66f83daa 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/tree/ssh/__init__.py @@ -1,23 +1,19 @@ -import errno import getpass import io -import itertools import logging import os import posixpath import threading -from concurrent.futures import ThreadPoolExecutor from contextlib import closing, contextmanager from urllib.parse import urlparse from funcy import first, memoize, silent, wrap_with import dvc.prompt as prompt -from dvc.progress import Tqdm -from dvc.remote.base import BaseRemoteTree, Remote -from dvc.remote.pool import get_connection from dvc.scheme import Schemes -from dvc.utils import to_chunks + +from ..base import BaseRemoteTree +from ..pool import get_connection logger = 
logging.getLogger(__name__) @@ -283,65 +279,3 @@ def list_paths(self, prefix=None, progress_callback=None): yield path else: yield from ssh.walk_files(root) - - -class SSHRemote(Remote): - def batch_exists(self, path_infos, callback): - def _exists(chunk_and_channel): - chunk, channel = chunk_and_channel - ret = [] - for path in chunk: - try: - channel.stat(path) - ret.append(True) - except OSError as exc: - if exc.errno != errno.ENOENT: - raise - ret.append(False) - callback(path) - return ret - - with self.tree.ssh(path_infos[0]) as ssh: - channels = ssh.open_max_sftp_channels() - max_workers = len(channels) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - paths = [path_info.path for path_info in path_infos] - chunks = to_chunks(paths, num_chunks=max_workers) - chunks_and_channels = zip(chunks, channels) - outcome = executor.map(_exists, chunks_and_channels) - results = list(itertools.chain.from_iterable(outcome)) - - return results - - def hashes_exist(self, hashes, jobs=None, name=None): - """This is older implementation used in remote/base.py - We are reusing it in RemoteSSH, because SSH's batch_exists proved to be - faster than current approach (relying on exists(path_info)) applied in - remote/base. - """ - if not self.tree.CAN_TRAVERSE: - return list(set(hashes) & set(self.tree.all())) - - # possibly prompt for credentials before "Querying" progress output - self.tree.ensure_credentials() - - with Tqdm( - desc="Querying " - + ("cache in " + name if name else "remote cache"), - total=len(hashes), - unit="file", - ) as pbar: - - def exists_with_progress(chunks): - return self.batch_exists(chunks, callback=pbar.update_msg) - - with ThreadPoolExecutor( - max_workers=jobs or self.tree.JOBS - ) as executor: - path_infos = [self.hash_to_path_info(x) for x in hashes] - chunks = to_chunks(path_infos, num_chunks=self.tree.JOBS) - results = executor.map(exists_with_progress, chunks) - in_remote = itertools.chain.from_iterable(results) - ret = list(itertools.compress(hashes, in_remote)) - return ret diff --git a/dvc/remote/ssh/connection.py b/dvc/tree/ssh/connection.py similarity index 99% rename from dvc/remote/ssh/connection.py rename to dvc/tree/ssh/connection.py index d3559314bb..6c8e258561 100644 --- a/dvc/remote/ssh/connection.py +++ b/dvc/tree/ssh/connection.py @@ -9,9 +9,10 @@ from dvc.exceptions import DvcException from dvc.progress import Tqdm -from dvc.remote.base import RemoteCmdError from dvc.utils import tmp_fname +from ..base import RemoteCmdError + try: import paramiko except ImportError: diff --git a/tests/conftest.py b/tests/conftest.py index c9e5c53403..796f465ea1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ def reset_loglevel(request, caplog): @pytest.fixture(scope="session", autouse=True) def _close_pools(): - from dvc.remote.pool import close_pools + from dvc.tree.pool import close_pools yield close_pools() diff --git a/tests/func/remote/test_gdrive.py b/tests/func/remote/test_gdrive.py index 0daecaee21..022391a298 100644 --- a/tests/func/remote/test_gdrive.py +++ b/tests/func/remote/test_gdrive.py @@ -4,8 +4,8 @@ import configobj from dvc.main import main -from dvc.remote.gdrive import GDriveRemoteTree from dvc.repo import Repo +from dvc.tree.gdrive import GDriveRemoteTree def test_relative_user_credentials_file_config_setting(tmp_dir, dvc): diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 8babe59764..6df3311499 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -332,7 +332,7 @@ def 
test_dir(self): def test_should_update_state_entry_for_file_after_add(mocker, dvc, tmp_dir): - file_md5_counter = mocker.spy(dvc_module.remote.local, "file_md5") + file_md5_counter = mocker.spy(dvc_module.tree.local, "file_md5") tmp_dir.gen("foo", "foo") ret = main(["config", "cache.type", "copy"]) @@ -363,7 +363,7 @@ def test_should_update_state_entry_for_file_after_add(mocker, dvc, tmp_dir): def test_should_update_state_entry_for_directory_after_add( mocker, dvc, tmp_dir ): - file_md5_counter = mocker.spy(dvc_module.remote.local, "file_md5") + file_md5_counter = mocker.spy(dvc_module.tree.local, "file_md5") tmp_dir.gen({"data/data": "foo", "data/data_sub/sub_data": "foo"}) diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index c3cdb04fb0..c1482fa1c7 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -17,12 +17,12 @@ ) from dvc.main import main from dvc.remote.base import CloudCache, Remote -from dvc.remote.local import LocalRemoteTree -from dvc.remote.s3 import S3RemoteTree from dvc.repo import Repo as DvcRepo from dvc.stage import Stage from dvc.stage.exceptions import StageFileDoesNotExistError from dvc.system import System +from dvc.tree.local import LocalRemoteTree +from dvc.tree.s3 import S3RemoteTree from dvc.utils import relpath from dvc.utils.fs import walk_files from dvc.utils.yaml import dump_yaml, load_yaml @@ -760,12 +760,12 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): dvc.cache.s3 = CloudCache(S3RemoteTree(dvc, {"url": S3.get_url()})) remote = Remote(S3RemoteTree(dvc, {"url": S3.get_url()})) - file_path = remote.path_info / "foo" + file_path = remote.tree.path_info / "foo" remote.tree.s3.put_object( - Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo" + Bucket=remote.tree.path_info.bucket, Key=file_path.path, Body="foo" ) - dvc.add(str(remote.path_info / "foo"), external=True) + dvc.add(str(remote.tree.path_info / "foo"), external=True) remote.tree.remove(file_path) stats = dvc.checkout(force=True) @@ -773,7 +773,9 @@ def test_checkout_for_external_outputs(tmp_dir, dvc): assert remote.tree.exists(file_path) remote.tree.s3.put_object( - Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo\nfoo" + Bucket=remote.tree.path_info.bucket, + Key=file_path.path, + Body="foo\nfoo", ) stats = dvc.checkout(force=True) assert stats == {**empty_checkout, "modified": [str(file_path)]} diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 1e61e6fd39..db0a52d6d6 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -10,8 +10,8 @@ from dvc.exceptions import DownloadError, UploadError from dvc.main import main from dvc.path_info import PathInfo -from dvc.remote.base import BaseRemoteTree, RemoteCacheRequiredError -from dvc.remote.local import LocalRemoteTree +from dvc.tree.base import BaseRemoteTree, RemoteCacheRequiredError +from dvc.tree.local import LocalRemoteTree from dvc.utils.fs import remove from tests.basic_env import TestDvc from tests.remotes import Local @@ -192,9 +192,13 @@ def unreliable_upload(self, from_file, to_info, name=None, **kwargs): assert upload_error_info.value.amount == 3 remote = dvc.cloud.get_remote("upstream") - assert not remote.tree.exists(remote.hash_to_path_info(foo.checksum)) - assert remote.tree.exists(remote.hash_to_path_info(bar.checksum)) - assert not remote.tree.exists(remote.hash_to_path_info(baz.checksum)) + assert not remote.tree.exists( + remote.tree.hash_to_path_info(foo.checksum) + ) + assert 
remote.tree.exists(remote.tree.hash_to_path_info(bar.checksum)) + assert not remote.tree.exists( + remote.tree.hash_to_path_info(baz.checksum) + ) # Push everything and delete local cache dvc.push() @@ -388,7 +392,7 @@ def test_protect_local_remote(tmp_dir, dvc, local_remote): dvc.push() remote = dvc.cloud.get_remote("upstream") - remote_cache_file = remote.hash_to_path_info(stage.outs[0].checksum) + remote_cache_file = remote.tree.hash_to_path_info(stage.outs[0].checksum) assert os.path.exists(remote_cache_file) assert stat.S_IMODE(os.stat(remote_cache_file).st_mode) == 0o444 diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index 679ccaed9a..017721151c 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -6,7 +6,7 @@ from moto import mock_s3 from dvc.remote.base import CloudCache -from dvc.remote.s3 import S3RemoteTree +from dvc.tree.s3 import S3RemoteTree from tests.remotes import S3 # from https://github.com/spulec/moto/blob/v1.3.5/tests/test_s3/test_s3.py#L40 diff --git a/tests/remotes/gdrive.py b/tests/remotes/gdrive.py index 1246308ddf..05f7834275 100644 --- a/tests/remotes/gdrive.py +++ b/tests/remotes/gdrive.py @@ -6,7 +6,7 @@ from funcy import cached_property from dvc.path_info import CloudURLInfo -from dvc.remote.gdrive import GDriveRemoteTree +from dvc.tree.gdrive import GDriveRemoteTree from .base import Base diff --git a/tests/remotes/ssh.py b/tests/remotes/ssh.py index 098fcf0efe..0880c3647d 100644 --- a/tests/remotes/ssh.py +++ b/tests/remotes/ssh.py @@ -47,7 +47,7 @@ def config(self): @contextmanager def _ssh(self): - from dvc.remote.ssh.connection import SSHConnection + from dvc.tree.ssh.connection import SSHConnection conn = SSHConnection( host=self.host, @@ -117,7 +117,7 @@ def ssh_server(): @pytest.fixture def ssh_connection(ssh_server): - from dvc.remote.ssh.connection import SSHConnection + from dvc.tree.ssh.connection import SSHConnection yield SSHConnection( host=ssh_server.host, @@ -129,7 +129,7 @@ def ssh_connection(ssh_server): @pytest.fixture def ssh(ssh_server, monkeypatch): - from dvc.remote.ssh import SSHRemoteTree + from dvc.tree.ssh import SSHRemoteTree # NOTE: see http://github.com/iterative/dvc/pull/3501 monkeypatch.setattr(SSHRemoteTree, "CAN_TRAVERSE", False) diff --git a/tests/unit/remote/ssh/test_pool.py b/tests/unit/remote/ssh/test_pool.py index f3094eeb82..42abed76d9 100644 --- a/tests/unit/remote/ssh/test_pool.py +++ b/tests/unit/remote/ssh/test_pool.py @@ -1,7 +1,7 @@ import pytest -from dvc.remote.pool import get_connection -from dvc.remote.ssh.connection import SSHConnection +from dvc.tree.pool import get_connection +from dvc.tree.ssh.connection import SSHConnection from tests.remotes.ssh import TEST_SSH_KEY_PATH, TEST_SSH_USER diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index 958c527300..e679ec323b 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -5,8 +5,8 @@ import pytest from mock import mock_open, patch -from dvc.remote.ssh import SSHRemoteTree from dvc.system import System +from dvc.tree.ssh import SSHRemoteTree def test_url(dvc): diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index 200639e117..0687637199 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -1,5 +1,5 @@ from dvc.path_info import PathInfo -from dvc.remote.azure import AzureRemoteTree +from dvc.tree.azure import AzureRemoteTree container_name = "container-name" connection_string = ( diff --git 
a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 016d71271a..ac146d2fe4 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -4,9 +4,9 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.base import ( +from dvc.remote.base import Remote +from dvc.tree.base import ( BaseRemoteTree, - Remote, RemoteCmdError, RemoteMissingDepsError, ) diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index 1fcae5dd7d..5fba03bafa 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -2,7 +2,7 @@ import pytest -from dvc.remote.gdrive import GDriveAuthError, GDriveRemoteTree +from dvc.tree.gdrive import GDriveAuthError, GDriveRemoteTree USER_CREDS_TOKEN_REFRESH_ERROR = '{"access_token": "", "client_id": "", "client_secret": "", "refresh_token": "", "token_expiry": "", "token_uri": "https://oauth2.googleapis.com/token", "user_agent": null, "revoke_uri": "https://oauth2.googleapis.com/revoke", "id_token": null, "id_token_jwt": null, "token_response": {"access_token": "", "expires_in": 3600, "scope": "https://www.googleapis.com/auth/drive.appdata https://www.googleapis.com/auth/drive", "token_type": "Bearer"}, "scopes": ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/drive.appdata"], "token_info_uri": "https://oauth2.googleapis.com/tokeninfo", "invalid": true, "_class": "OAuth2Credentials", "_module": "oauth2client.client"}' # noqa: E501 diff --git a/tests/unit/remote/test_gs.py b/tests/unit/remote/test_gs.py index ec86cfb5e8..c519fbbd68 100644 --- a/tests/unit/remote/test_gs.py +++ b/tests/unit/remote/test_gs.py @@ -2,7 +2,7 @@ import pytest import requests -from dvc.remote.gs import GSRemoteTree, dynamic_chunk_size +from dvc.tree.gs import GSRemoteTree, dynamic_chunk_size BUCKET = "bucket" PREFIX = "prefix" diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 31e850c91c..13677d64f2 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -1,7 +1,7 @@ import pytest from dvc.exceptions import HTTPError -from dvc.remote.http import HTTPRemoteTree +from dvc.tree.http import HTTPRemoteTree def test_download_fails_on_error_code(dvc, http): diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py index 24aafc1be5..ebe49f25bd 100644 --- a/tests/unit/remote/test_local.py +++ b/tests/unit/remote/test_local.py @@ -6,7 +6,8 @@ from dvc.cache import NamedCache from dvc.path_info import PathInfo from dvc.remote.index import RemoteIndexNoop -from dvc.remote.local import LocalCache, LocalRemoteTree +from dvc.remote.local import LocalCache +from dvc.tree.local import LocalRemoteTree def test_status_download_optimization(mocker, dvc): diff --git a/tests/unit/remote/test_oss.py b/tests/unit/remote/test_oss.py index c88804876d..b8b0f385c0 100644 --- a/tests/unit/remote/test_oss.py +++ b/tests/unit/remote/test_oss.py @@ -1,4 +1,4 @@ -from dvc.remote.oss import OSSRemoteTree +from dvc.tree.oss import OSSRemoteTree bucket_name = "bucket-name" endpoint = "endpoint" diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py index 7b2677759e..7f5794d532 100644 --- a/tests/unit/remote/test_remote.py +++ b/tests/unit/remote/test_remote.py @@ -1,8 +1,8 @@ import pytest -from dvc.remote import get_cloud_tree -from dvc.remote.gs import GSRemoteTree -from dvc.remote.s3 import S3RemoteTree +from dvc.tree import get_cloud_tree +from dvc.tree.gs import GSRemoteTree +from dvc.tree.s3 
import S3RemoteTree def test_remote_with_hash_jobs(dvc): diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index 0c01ab89b6..c6eaecd3c4 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -4,7 +4,7 @@ from dvc.path_info import PathInfo from dvc.remote import get_remote -from dvc.remote.s3 import S3RemoteTree +from dvc.tree.s3 import S3RemoteTree from dvc.utils.fs import walk_files remotes = [pytest.lazy_fixture(fix) for fix in ["gs", "s3"]] @@ -46,7 +46,7 @@ def test_isdir(remote): ] for expected, path in test_cases: - assert remote.tree.isdir(remote.path_info / path) == expected + assert remote.tree.isdir(remote.tree.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True) @@ -66,22 +66,24 @@ def test_exists(remote): ] for expected, path in test_cases: - assert remote.tree.exists(remote.path_info / path) == expected + assert remote.tree.exists(remote.tree.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True) def test_walk_files(remote): files = [ - remote.path_info / "data/alice", - remote.path_info / "data/alpha", - remote.path_info / "data/subdir-file.txt", - remote.path_info / "data/subdir/1", - remote.path_info / "data/subdir/2", - remote.path_info / "data/subdir/3", - remote.path_info / "data/subdir/empty_file", + remote.tree.path_info / "data/alice", + remote.tree.path_info / "data/alpha", + remote.tree.path_info / "data/subdir-file.txt", + remote.tree.path_info / "data/subdir/1", + remote.tree.path_info / "data/subdir/2", + remote.tree.path_info / "data/subdir/3", + remote.tree.path_info / "data/subdir/empty_file", ] - assert list(remote.tree.walk_files(remote.path_info / "data")) == files + assert ( + list(remote.tree.walk_files(remote.tree.path_info / "data")) == files + ) @pytest.mark.parametrize("remote", [pytest.lazy_fixture("s3")], indirect=True) @@ -91,7 +93,7 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): another = S3RemoteTree(dvc, {"url": "s3://another", "region": "us-east-1"}) - from_info = remote.path_info / "foo" + from_info = remote.tree.path_info / "foo" to_info = another.path_info / "foo" remote.tree.copy(from_info, to_info) @@ -105,7 +107,7 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): @pytest.mark.parametrize("remote", remotes, indirect=True) def test_makedirs(remote): tree = remote.tree - empty_dir = remote.path_info / "empty_dir" / "" + empty_dir = remote.tree.path_info / "empty_dir" / "" tree.remove(empty_dir) assert not tree.exists(empty_dir) tree.makedirs(empty_dir) @@ -132,14 +134,14 @@ def test_isfile(remote): ] for expected, path in test_cases: - assert remote.tree.isfile(remote.path_info / path) == expected + assert remote.tree.isfile(remote.tree.path_info / path) == expected @pytest.mark.parametrize("remote", remotes, indirect=True) def test_download_dir(remote, tmpdir): path = str(tmpdir / "data") to_info = PathInfo(path) - remote.tree.download(remote.path_info / "data", to_info) + remote.tree.download(remote.tree.path_info / "data", to_info) assert os.path.isdir(path) data_dir = tmpdir / "data" assert len(list(walk_files(path))) == 7 diff --git a/tests/unit/remote/test_s3.py b/tests/unit/remote/test_s3.py index 4623bd4144..570371c0c7 100644 --- a/tests/unit/remote/test_s3.py +++ b/tests/unit/remote/test_s3.py @@ -1,7 +1,7 @@ import pytest from dvc.config import ConfigError -from dvc.remote.s3 import S3RemoteTree +from dvc.tree.s3 import S3RemoteTree bucket_name = 
"bucket-name" prefix = "some/prefix"