diff --git a/dvc/config.py b/dvc/config.py index 34addc400b..50cd8feb3d 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -92,7 +92,7 @@ class RelPath(str): REMOTE_COMMON = { "url": str, "checksum_jobs": All(Coerce(int), Range(1)), - "no_traverse": Bool, + Optional("no_traverse"): Bool, # obsoleted "verify": Bool, } LOCAL_COMMON = { diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index cfc8032d96..23770ad18e 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -1,5 +1,6 @@ import logging import os +import posixpath import re from datetime import datetime, timedelta from urllib.parse import urlparse @@ -103,8 +104,14 @@ def _list_paths(self, bucket, prefix): next_marker = blobs.next_marker - def list_cache_paths(self): - return self._list_paths(self.path_info.bucket, self.path_info.path) + def list_cache_paths(self, prefix=None): + if prefix: + prefix = posixpath.join( + self.path_info.path, prefix[:2], prefix[2:] + ) + else: + prefix = self.path_info.path + return self._list_paths(self.path_info.bucket, prefix) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 67941b37a1..2c41e32279 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -80,8 +80,12 @@ class RemoteBASE(object): CHECKSUM_DIR_SUFFIX = ".dir" CHECKSUM_JOBS = max(1, min(4, cpu_count() // 2)) DEFAULT_CACHE_TYPES = ["copy"] - DEFAULT_NO_TRAVERSE = True DEFAULT_VERIFY = False + LIST_OBJECT_PAGE_SIZE = 1000 + TRAVERSE_WEIGHT_MULTIPLIER = 20 + TRAVERSE_PREFIX_LEN = 3 + TRAVERSE_THRESHOLD_SIZE = 500000 + CAN_TRAVERSE = True CACHE_MODE = None SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} @@ -101,7 +105,6 @@ def __init__(self, repo, config): or (self.repo and self.repo.config["core"].get("checksum_jobs")) or self.CHECKSUM_JOBS ) - self.no_traverse = config.get("no_traverse", self.DEFAULT_NO_TRAVERSE) self.verify = config.get("verify", self.DEFAULT_VERIFY) self._dir_info = {} @@ -686,7 +689,7 @@ def path_to_checksum(self, path): def checksum_to_path_info(self, checksum): return self.path_info / checksum[0:2] / checksum[2:] - def list_cache_paths(self): + def list_cache_paths(self, prefix=None): raise NotImplementedError def all(self): @@ -802,11 +805,13 @@ def cache_exists(self, checksums, jobs=None, name=None): There are two ways of performing this check: - - Traverse: Get a list of all the files in the remote + - Traverse method: Get a list of all the files in the remote (traversing the cache directory) and compare it with - the given checksums. + the given checksums. Cache entries will be retrieved in parallel + threads according to prefix (i.e. entries starting with, "00...", + "01...", and so on) and a progress bar will be displayed. - - No traverse: For each given checksum, run the `exists` + - Exists method: For each given checksum, run the `exists` method and filter the checksums that aren't on the remote. This is done in parallel threads. It also shows a progress bar when performing the check. @@ -817,12 +822,120 @@ def cache_exists(self, checksums, jobs=None, name=None): check if particular file exists much quicker, use their own implementation of cache_exists (see ssh, local). + Which method to use will be automatically determined after estimating + the size of the remote cache, and comparing the estimated size with + len(checksums). To estimate the size of the remote cache, we fetch + a small subset of cache entries (i.e. entries starting with "00..."). 
+ Based on the number of entries in that subset, the size of the full + cache can be estimated, since the cache is evenly distributed according + to checksum. + Returns: A list with checksums that were found in the remote """ - if not self.no_traverse: - return list(set(checksums) & set(self.all())) + # Remotes which do not use traverse prefix should override + # cache_exists() (see ssh, local) + assert self.TRAVERSE_PREFIX_LEN >= 2 + + if len(checksums) == 1 or not self.CAN_TRAVERSE: + return self._cache_object_exists(checksums, jobs, name) + + # Fetch cache entries beginning with "00..." prefix for estimating the + # size of entire remote cache + checksums = frozenset(checksums) + prefix = "0" * self.TRAVERSE_PREFIX_LEN + total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) + remote_checksums = set( + map(self.path_to_checksum, self.list_cache_paths(prefix=prefix)) + ) + if remote_checksums: + remote_size = total_prefixes * len(remote_checksums) + else: + remote_size = total_prefixes + logger.debug("Estimated remote size: {} files".format(remote_size)) + + traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE + # For sufficiently large remotes, traverse must be weighted to account + # for performance overhead from large lists/sets. + # From testing with S3, for remotes with 1M+ files, object_exists is + # faster until len(checksums) is at least 10k~100k + if remote_size > self.TRAVERSE_THRESHOLD_SIZE: + traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER + else: + traverse_weight = traverse_pages + if len(checksums) < traverse_weight: + logger.debug( + "Large remote ('{}' checksums < '{}' traverse weight), " + "using object_exists for remaining checksums".format( + len(checksums), traverse_weight + ) + ) + return list( + checksums & remote_checksums + ) + self._cache_object_exists( + checksums - remote_checksums, jobs, name + ) + if traverse_pages < 256 / self.JOBS: + # Threaded traverse will require making at least 255 more requests + # to the remote, so for small enough remotes, fetching the entire + # list at once will require fewer requests (but also take into + # account that this must be done sequentially rather than in + # parallel) + logger.debug( + "Querying {} checksums via default traverse".format( + len(checksums) + ) + ) + return list(checksums & set(self.all())) + + return self._cache_exists_traverse( + checksums, remote_checksums, jobs, name + ) + + def _cache_exists_traverse( + self, checksums, remote_checksums, jobs=None, name=None + ): + logger.debug( + "Querying {} checksums via threaded traverse".format( + len(checksums) + ) + ) + + traverse_prefixes = ["{:02x}".format(i) for i in range(1, 256)] + if self.TRAVERSE_PREFIX_LEN > 2: + traverse_prefixes += [ + "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) + for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) + ] + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=len(traverse_prefixes), + unit="dir", + ) as pbar: + + def list_with_update(prefix): + ret = map( + self.path_to_checksum, + list(self.list_cache_paths(prefix=prefix)), + ) + pbar.update_desc( + "Querying cache in '{}'".format(self.path_info / prefix) + ) + return ret + + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + in_remote = executor.map(list_with_update, traverse_prefixes,) + remote_checksums.update( + itertools.chain.from_iterable(in_remote) + ) + return list(checksums & remote_checksums) + + def _cache_object_exists(self, checksums, jobs=None, name=None): + 
logger.debug(
+            "Querying {} checksums via object_exists".format(len(checksums))
+        )
         with Tqdm(
             desc="Querying "
             + ("cache in " + name if name else "remote cache"),
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py
index 56499e8537..484bd1a827 100644
--- a/dvc/remote/gdrive.py
+++ b/dvc/remote/gdrive.py
@@ -73,8 +73,10 @@ class RemoteGDrive(RemoteBASE):
     scheme = Schemes.GDRIVE
     path_cls = GDriveURLInfo
     REQUIRES = {"pydrive2": "pydrive2"}
-    DEFAULT_NO_TRAVERSE = False
     DEFAULT_VERIFY = True
+    # Always prefer traverse for GDrive since API usage quotas are a concern.
+    TRAVERSE_WEIGHT_MULTIPLIER = 1
+    TRAVERSE_PREFIX_LEN = 2
 
     GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA"
     DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json"
@@ -414,12 +416,18 @@ def _download(self, from_info, to_file, name, no_progress_bar):
         file_id = self._get_remote_id(from_info)
         self.gdrive_download_file(file_id, to_file, name, no_progress_bar)
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         if not self.cache["ids"]:
             return
 
+        if prefix:
+            dir_ids = self.cache["dirs"].get(prefix[:2])
+            if not dir_ids:
+                return
+        else:
+            dir_ids = self.cache["ids"]
         parents_query = " or ".join(
-            "'{}' in parents".format(dir_id) for dir_id in self.cache["ids"]
+            "'{}' in parents".format(dir_id) for dir_id in dir_ids
         )
         query = "({}) and trashed=false".format(parents_query)
diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py
index 9230699535..eab424458f 100644
--- a/dvc/remote/gs.py
+++ b/dvc/remote/gs.py
@@ -3,6 +3,7 @@
 from functools import wraps
 import io
 import os.path
+import posixpath
 import threading
 
 from funcy import cached_property, wrap_prop
@@ -126,14 +127,18 @@ def remove(self, path_info):
 
         blob.delete()
 
-    def _list_paths(self, path_info, max_items=None):
+    def _list_paths(self, path_info, max_items=None, prefix=None):
+        if prefix:
+            prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:])
+        else:
+            prefix = path_info.path
         for blob in self.gs.bucket(path_info.bucket).list_blobs(
-            prefix=path_info.path, max_results=max_items
+            prefix=prefix, max_results=max_items
         ):
             yield blob.name
 
-    def list_cache_paths(self):
-        return self._list_paths(self.path_info)
+    def list_cache_paths(self, prefix=None):
+        return self._list_paths(self.path_info, prefix=prefix)
 
     def walk_files(self, path_info):
         for fname in self._list_paths(path_info / ""):
diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py
index 6c6cd62bab..f7e3439740 100644
--- a/dvc/remote/hdfs.py
+++ b/dvc/remote/hdfs.py
@@ -21,6 +21,7 @@ class RemoteHDFS(RemoteBASE):
     REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
     PARAM_CHECKSUM = "checksum"
     REQUIRES = {"pyarrow": "pyarrow"}
+    TRAVERSE_PREFIX_LEN = 2
 
     def __init__(self, repo, config):
         super().__init__(repo, config)
@@ -152,16 +153,26 @@ def open(self, path_info, mode="r", encoding=None):
                 raise FileNotFoundError(*e.args)
             raise
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         if not self.exists(self.path_info):
             return
-        dirs = deque([self.path_info.path])
+        if prefix:
+            root = posixpath.join(self.path_info.path, prefix[:2])
+        else:
+            root = self.path_info.path
+        dirs = deque([root])
         with self.hdfs(self.path_info) as hdfs:
             while dirs:
-                for entry in hdfs.ls(dirs.pop(), detail=True):
-                    if entry["kind"] == "directory":
-                        dirs.append(urlparse(entry["name"]).path)
-                    elif entry["kind"] == "file":
-                        yield urlparse(entry["name"]).path
+                try:
+                    for entry in hdfs.ls(dirs.pop(), detail=True):
+                        if entry["kind"] == "directory":
+                            dirs.append(urlparse(entry["name"]).path)
+                        elif entry["kind"] == "file":
+                            yield 
urlparse(entry["name"]).path + except IOError as e: + # When searching for a specific prefix pyarrow raises an + # exception if the specified cache dir does not exist + if not prefix: + raise e diff --git a/dvc/remote/http.py b/dvc/remote/http.py index fa4966685e..6ea016ca6d 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -6,7 +6,6 @@ from dvc.path_info import HTTPURLInfo import dvc.prompt as prompt -from dvc.config import ConfigError from dvc.exceptions import DvcException, HTTPError from dvc.progress import Tqdm from dvc.remote.base import RemoteBASE @@ -32,6 +31,7 @@ class RemoteHTTP(RemoteBASE): REQUEST_TIMEOUT = 10 CHUNK_SIZE = 2 ** 16 PARAM_CHECKSUM = "etag" + CAN_TRAVERSE = False def __init__(self, repo, config): super().__init__(repo, config) @@ -45,12 +45,6 @@ def __init__(self, repo, config): else: self.path_info = None - if not self.no_traverse: - raise ConfigError( - "HTTP doesn't support traversing the remote to list existing " - "files. Use: `dvc remote modify no_traverse true`" - ) - self.auth = config.get("auth", None) self.custom_auth_header = config.get("custom_auth_header", None) self.password = config.get("password", None) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index d1470620ec..25d409f67d 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -56,7 +56,8 @@ def cache_dir(self, value): def supported(cls, config): return True - def list_cache_paths(self): + def list_cache_paths(self, prefix=None): + assert prefix is None assert self.path_info is not None return walk_files(self.path_info) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 989eb829fa..9d928f7cb0 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -1,5 +1,6 @@ import logging import os +import posixpath import threading from funcy import cached_property, wrap_prop @@ -95,8 +96,14 @@ def _list_paths(self, prefix): for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): yield blob.key - def list_cache_paths(self): - return self._list_paths(self.path_info.path) + def list_cache_paths(self, prefix=None): + if prefix: + prefix = posixpath.join( + self.path_info.path, prefix[:2], prefix[2:] + ) + else: + prefix = self.path_info.path + return self._list_paths(prefix) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 61636a0154..7bc9b0fcc1 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -2,6 +2,7 @@ import logging import os +import posixpath import threading from funcy import cached_property, wrap_prop @@ -190,24 +191,27 @@ def remove(self, path_info): logger.debug("Removing {}".format(path_info)) self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) - def _list_objects(self, path_info, max_items=None): + def _list_objects(self, path_info, max_items=None, prefix=None): """ Read config for list object api, paginate through list objects.""" kwargs = { "Bucket": path_info.bucket, "Prefix": path_info.path, "PaginationConfig": {"MaxItems": max_items}, } + if prefix: + kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2]) paginator = self.s3.get_paginator(self.list_objects_api) for page in paginator.paginate(**kwargs): yield from page.get("Contents", ()) - def _list_paths(self, path_info, max_items=None): + def _list_paths(self, path_info, max_items=None, prefix=None): return ( - item["Key"] for item in self._list_objects(path_info, max_items) + item["Key"] + for item in self._list_objects(path_info, max_items, prefix) ) - def list_cache_paths(self): 
- return self._list_paths(self.path_info) + def list_cache_paths(self, prefix=None): + return self._list_paths(self.path_info, prefix=prefix) def isfile(self, path_info): from botocore.exceptions import ClientError diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 09c23bc6be..1f191fff65 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -249,7 +249,8 @@ def open(self, path_info, mode="r", encoding=None): else: yield io.TextIOWrapper(fd, encoding=encoding) - def list_cache_paths(self): + def list_cache_paths(self, prefix=None): + assert prefix is None with self.ssh(self.path_info) as ssh: # If we simply return an iterator then with above closes instantly yield from ssh.walk_files(self.path_info.path) @@ -297,7 +298,7 @@ def cache_exists(self, checksums, jobs=None, name=None): faster than current approach (relying on exists(path_info)) applied in remote/base. """ - if not self.no_traverse: + if not self.CAN_TRAVERSE: return list(set(checksums) & set(self.all())) # possibly prompt for credentials before "Querying" progress output diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 590094d8a8..976de78417 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -264,11 +264,11 @@ def _setup_cloud(self): repo = self.get_url() keyfile = self._get_keyfile() + self._get_cloud_class().CAN_TRAVERSE = False config = copy.deepcopy(TEST_CONFIG) config["remote"][TEST_REMOTE] = { "url": repo, "keyfile": keyfile, - "no_traverse": False, } self.dvc.config = config self.cloud = DataCloud(self.dvc) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 2199886350..54ec121a9a 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -2,6 +2,7 @@ import mock +from dvc.path_info import PathInfo from dvc.remote.base import RemoteBASE from dvc.remote.base import RemoteCmdError from dvc.remote.base import RemoteMissingDepsError @@ -35,3 +36,77 @@ def test(self): ): with self.assertRaises(RemoteCmdError): self.REMOTE_CLS(repo, config).remove("file") + + +@mock.patch.object(RemoteBASE, "_cache_exists_traverse") +@mock.patch.object(RemoteBASE, "_cache_object_exists") +@mock.patch.object( + RemoteBASE, "path_to_checksum", side_effect=lambda x: x, +) +def test_cache_exists(path_to_checksum, object_exists, traverse): + remote = RemoteBASE(None, {}) + + # remote does not support traverse + remote.CAN_TRAVERSE = False + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000)) + remote.cache_exists(checksums) + object_exists.assert_called_with(checksums, None, None) + traverse.assert_not_called() + + remote.CAN_TRAVERSE = True + + # large remote, small local + object_exists.reset_mock() + traverse.reset_mock() + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000)) + remote.cache_exists(checksums) + object_exists.assert_called_with( + frozenset(range(256, 1000)), None, None + ) + traverse.assert_not_called() + + # large remote, large local + object_exists.reset_mock() + traverse.reset_mock() + remote.JOBS = 16 + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000000)) + remote.cache_exists(checksums) + object_exists.assert_not_called() + traverse.assert_called_with( + frozenset(checksums), set(range(256)), None, None + ) + + # default traverse + object_exists.reset_mock() + 
traverse.reset_mock() + remote.TRAVERSE_WEIGHT_MULTIPLIER = 1 + with mock.patch.object(remote, "list_cache_paths", return_value=[0]): + checksums = set(range(1000000)) + remote.cache_exists(checksums) + traverse.assert_not_called() + object_exists.assert_not_called() + + +@mock.patch.object( + RemoteBASE, "list_cache_paths", return_value=[], +) +@mock.patch.object( + RemoteBASE, "path_to_checksum", side_effect=lambda x: x, +) +def test_cache_exists_traverse(path_to_checksum, list_cache_paths): + remote = RemoteBASE(None, {}) + remote.path_info = PathInfo("foo") + remote._cache_exists_traverse({0}, set()) + for i in range(1, 16): + list_cache_paths.assert_any_call(prefix="{:03x}".format(i)) + for i in range(1, 256): + list_cache_paths.assert_any_call(prefix="{:02x}".format(i)) diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index e165903ccb..72b790ea92 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -28,7 +28,6 @@ class TestRemoteGDrive(object): def test_init(self): remote = RemoteGDrive(Repo(), self.CONFIG) assert str(remote.path_info) == self.CONFIG["url"] - assert not remote.no_traverse def test_drive(self): remote = RemoteGDrive(Repo(), self.CONFIG) diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 65ad2d2e85..f837e06178 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -1,23 +1,11 @@ import pytest -from dvc.config import ConfigError from dvc.exceptions import HTTPError from dvc.path_info import URLInfo from dvc.remote.http import RemoteHTTP from tests.utils.httpd import StaticFileServer -def test_no_traverse_compatibility(dvc): - config = { - "url": "http://example.com/", - "path_info": "file.html", - "no_traverse": False, - } - - with pytest.raises(ConfigError): - RemoteHTTP(dvc, config) - - def test_download_fails_on_error_code(dvc): with StaticFileServer() as httpd: url = "http://localhost:{}/".format(httpd.server_port) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py deleted file mode 100644 index d71d3c4077..0000000000 --- a/tests/unit/test_config.py +++ /dev/null @@ -1,16 +0,0 @@ -from dvc.config import COMPILED_SCHEMA - - -def test_remote_config_no_traverse(): - d = COMPILED_SCHEMA({"remote": {"myremote": {"url": "url"}}}) - assert "no_traverse" not in d["remote"]["myremote"] - - d = COMPILED_SCHEMA( - {"remote": {"myremote": {"url": "url", "no_traverse": "fAlSe"}}} - ) - assert not d["remote"]["myremote"]["no_traverse"] - - d = COMPILED_SCHEMA( - {"remote": {"myremote": {"url": "url", "no_traverse": "tRuE"}}} - ) - assert d["remote"]["myremote"]["no_traverse"]
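
Note (not part of the patch): the method-selection logic in the new cache_exists() is easier to follow outside the diff. The sketch below mirrors that heuristic under stated assumptions: sample the entries under the all-zeros prefix, extrapolate the remote size, then choose between per-checksum object_exists queries, a single full listing, and the threaded per-prefix traverse. The constants copy the values added to RemoteBASE; the helper names (estimate_remote_size, choose_method, sampled_count) are illustrative only and not part of the DVC API, and the real method additionally short-circuits to object_exists when only one checksum is queried or CAN_TRAVERSE is false.

# Illustrative sketch only -- mirrors the heuristic added to
# RemoteBASE.cache_exists() in this patch, with hypothetical names.
LIST_OBJECT_PAGE_SIZE = 1000
TRAVERSE_WEIGHT_MULTIPLIER = 20
TRAVERSE_PREFIX_LEN = 3
TRAVERSE_THRESHOLD_SIZE = 500000
JOBS = 4  # stand-in for remote.JOBS


def estimate_remote_size(sampled_count):
    """Extrapolate the remote size from the entries under one prefix.

    sampled_count is the number of cache entries starting with
    "0" * TRAVERSE_PREFIX_LEN (e.g. "000").  Since checksums are evenly
    distributed, the full cache is roughly 16**TRAVERSE_PREFIX_LEN times
    larger than that sample (with a minimum of one entry per prefix).
    """
    total_prefixes = pow(16, TRAVERSE_PREFIX_LEN)
    return total_prefixes * max(sampled_count, 1)


def choose_method(num_checksums, remote_size):
    """Return which existence-check strategy the patch would pick."""
    traverse_pages = remote_size / LIST_OBJECT_PAGE_SIZE
    # Large remotes pay an extra penalty for building big lists/sets,
    # so traverse is weighted before comparing it to len(checksums).
    if remote_size > TRAVERSE_THRESHOLD_SIZE:
        traverse_weight = traverse_pages * TRAVERSE_WEIGHT_MULTIPLIER
    else:
        traverse_weight = traverse_pages

    if num_checksums < traverse_weight:
        return "object_exists"      # query each checksum individually
    if traverse_pages < 256 / JOBS:
        return "default traverse"   # one sequential full listing
    return "threaded traverse"      # per-prefix listing in parallel


if __name__ == "__main__":
    # 256 entries under "000" => ~1M files estimated; 1k local checksums
    # are cheaper to check one by one.
    size = estimate_remote_size(256)
    print(size, choose_method(1000, size))       # 1048576 object_exists
    # The same remote with 1M local checksums favors threaded traverse.
    print(choose_method(1000000, size))          # threaded traverse
    # A tiny remote (nothing under "000") is listed in one pass.
    small = estimate_remote_size(0)
    print(small, choose_method(1000000, small))  # 4096 default traverse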
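
A second sketch, also outside the patch, shows how a traverse prefix maps onto the two-level cache layout used by the object-store remotes (the first two characters become the directory, the rest a key prefix), and how the prefix list walked by _cache_exists_traverse() is built for TRAVERSE_PREFIX_LEN values of 2 and 3. prefix_to_remote_path and traverse_prefixes are hypothetical helper names, not functions in DVC; the join pattern matches the posixpath.join(...) calls added to the S3, GS, Azure and OSS remotes above.

# Illustrative sketch of how traverse prefixes relate to the cache layout.
import posixpath


def prefix_to_remote_path(cache_root, prefix):
    """Map a checksum prefix onto the <root>/<2-char dir>/<rest> layout."""
    if prefix:
        return posixpath.join(cache_root, prefix[:2], prefix[2:])
    return cache_root


def traverse_prefixes(prefix_len):
    """Prefixes queried by the threaded traverse.

    The all-zeros prefix is fetched up front for size estimation, so it
    is skipped here, matching _cache_exists_traverse().
    """
    prefixes = ["{:02x}".format(i) for i in range(1, 256)]
    if prefix_len > 2:
        prefixes += [
            "{0:0{1}x}".format(i, prefix_len)
            for i in range(1, pow(16, prefix_len - 2))
        ]
    return prefixes


if __name__ == "__main__":
    print(prefix_to_remote_path("bucket/cache", "123"))  # bucket/cache/12/3
    print(prefix_to_remote_path("bucket/cache", None))   # bucket/cache
    print(len(traverse_prefixes(2)))  # 255: "01" .. "ff"
    print(len(traverse_prefixes(3)))  # 270: adds "001" .. "00f"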