From 97aedaa179dd396a625b7c6ae6ab8e6c03ab310a Mon Sep 17 00:00:00 2001
From: Peter Rowlands
Date: Wed, 18 Mar 2020 17:37:36 +0900
Subject: [PATCH 01/11] remote: Optimize traverse/no_traverse behavior

- estimate the remote file count by fetching a single parent cache dir,
  then determine whether to use the no_traverse method for checking the
  remaining cache entries in `cache_exists`
- thread requests when traversing full remote file lists (by fetching
  one parent cache dir per thread)
---
 dvc/remote/azure.py        |  9 +++-
 dvc/remote/base.py         | 93 ++++++++++++++++++++++++++++++++++++--
 dvc/remote/gdrive.py       | 12 ++++-
 dvc/remote/gs.py           | 11 +++--
 dvc/remote/hdfs.py         | 25 +++++++---
 dvc/remote/local.py        |  2 +-
 dvc/remote/oss.py          |  9 +++-
 dvc/remote/s3.py           | 14 ++++--
 dvc/remote/ssh/__init__.py |  2 +-
 9 files changed, 150 insertions(+), 27 deletions(-)

diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py
index cfc8032d96..280d85b56a 100644
--- a/dvc/remote/azure.py
+++ b/dvc/remote/azure.py
@@ -26,6 +26,7 @@ class RemoteAZURE(RemoteBASE):
         r"(?P<connection_string>.+)?)?)$"
     )
     REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
+    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
 
@@ -103,8 +104,12 @@ def _list_paths(self, bucket, prefix):
 
             next_marker = blobs.next_marker
 
-    def list_cache_paths(self):
-        return self._list_paths(self.path_info.bucket, self.path_info.path)
+    def list_cache_paths(self, prefix=None):
+        if prefix:
+            prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]])
+        else:
+            prefix = self.path_info.path
+        return self._list_paths(self.path_info.bucket, prefix)
 
     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 67941b37a1..5827cc8c52 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -82,6 +82,8 @@ class RemoteBASE(object):
     DEFAULT_CACHE_TYPES = ["copy"]
     DEFAULT_NO_TRAVERSE = True
     DEFAULT_VERIFY = False
+    LIST_OBJECT_PAGE_SIZE = 1000
+    TRAVERSE_WEIGHT_MULTIPLIER = 20
     CACHE_MODE = None
     SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
 
@@ -686,7 +688,7 @@ def path_to_checksum(self, path):
     def checksum_to_path_info(self, checksum):
        return self.path_info / checksum[0:2] / checksum[2:]
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         raise NotImplementedError
 
     def all(self):
@@ -804,7 +806,9 @@ def cache_exists(self, checksums, jobs=None, name=None):
 
         - Traverse: Get a list of all the files in the remote
             (traversing the cache directory) and compare it with
-            the given checksums.
+            the given checksums. Cache parent directories will
+            be retrieved in parallel threads, and a progress bar
+            will be displayed.
 
         - No traverse: For each given checksum, run the `exists`
             method and filter the checksums that aren't on the remote.
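
(Aside: a minimal sketch of the heuristic the next hunk introduces, condensed
into a standalone function for readers. `should_traverse`, `num_checksums`,
and `sampled_count` are illustrative names only, not names from the patch;
the constants mirror the RemoteBASE attributes added above, and the 500000
cutoff is hardcoded exactly as in this patch. The sampled prefix here is
"00", giving 256 total prefix dirs.)

    LIST_OBJECT_PAGE_SIZE = 1000
    TRAVERSE_WEIGHT_MULTIPLIER = 20

    def should_traverse(num_checksums, sampled_count, total_prefixes=256):
        """Return True when listing the whole remote (traverse) is likely
        cheaper than issuing one exists() request per checksum."""
        # The cache is evenly distributed by checksum, so the entry count
        # under one sampled prefix dir ("00") scaled by the number of
        # prefix dirs estimates the total remote size.
        remote_size = total_prefixes * max(sampled_count, 1)
        traverse_weight = remote_size / LIST_OBJECT_PAGE_SIZE
        if remote_size > 500000:
            # Large remotes are weighted more heavily to account for the
            # overhead of building and intersecting large lists/sets.
            traverse_weight *= TRAVERSE_WEIGHT_MULTIPLIER
        return num_checksums >= traverse_weight

(For a remote of roughly 1M files, sampled_count is about 4096 and the
resulting weight is about 20,000 checksums, consistent with the 10k~100k
crossover from S3 testing mentioned in the commit message.)
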
@@ -820,9 +824,90 @@ def cache_exists(self, checksums, jobs=None, name=None): Returns: A list with checksums that were found in the remote """ - if not self.no_traverse: - return list(set(checksums) & set(self.all())) + if self.no_traverse: + return self._cache_exists_no_traverse(checksums, jobs, name) + + # Fetch one parent cache dir for estimating size of entire remote cache + checksums = frozenset(checksums) + remote_checksums = set( + map(self.path_to_checksum, self.list_cache_paths(prefix="00")) + ) + if not remote_checksums: + remote_size = 256 + else: + remote_size = 256 * len(remote_checksums) + logger.debug("Estimated remote size: {} files".format(remote_size)) + + traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE + # For sufficiently large remotes, traverse must be weighted to account + # for performance overhead from large lists/sets. + # From testing with S3, for remotes with 1M+ files, no_traverse is + # faster until len(checksums) is at least 10k~100k + if remote_size > 500000: + traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER + else: + traverse_weight = traverse_pages + if len(checksums) < traverse_weight: + logger.debug( + "Large remote, using no_traverse for remaining checksums" + ) + return list( + checksums & remote_checksums + ) + self._cache_exists_no_traverse( + checksums - remote_checksums, jobs, name + ) + + if traverse_pages < 256 / self.JOBS: + # Threaded traverse will require making at least 255 more requests + # to the remote, so for small enough remotes, fetching the entire + # list at once will require fewer requests (but also take into + # account that this must be done sequentially rather than in + # parallel) + logger.debug( + "Querying {} checksums via default traverse".format( + len(checksums) + ) + ) + return list(checksums & set(self.all())) + + logger.debug( + "Querying {} checksums via threaded traverse".format( + len(checksums) + ) + ) + + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=256, + unit="dir", + ) as pbar: + + def list_with_update(prefix): + ret = map( + self.path_to_checksum, + list(self.list_cache_paths(prefix=prefix)), + ) + pbar.update_desc( + "Querying cache in {}/{}".format(self.path_info, prefix) + ) + return ret + + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + in_remote = executor.map( + list_with_update, + ["{:02x}".format(i) for i in range(1, 256)], + ) + remote_checksums.update( + itertools.chain.from_iterable(in_remote) + ) + return list(checksums & remote_checksums) + + def _cache_exists_no_traverse(self, checksums, jobs=None, name=None): + logger.debug( + "Querying {} checksums via no_traverse".format(len(checksums)) + ) with Tqdm( desc="Querying " + ("cache in " + name if name else "remote cache"), diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 56499e8537..175d8afdec 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -75,6 +75,8 @@ class RemoteGDrive(RemoteBASE): REQUIRES = {"pydrive2": "pydrive2"} DEFAULT_NO_TRAVERSE = False DEFAULT_VERIFY = True + # Always prefer traverse for GDrive since API usage quotas are a concern. 
+    TRAVERSE_WEIGHT_MULTIPLIER = 1
 
     GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA"
     DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json"
@@ -414,12 +416,18 @@ def _download(self, from_info, to_file, name, no_progress_bar):
         file_id = self._get_remote_id(from_info)
         self.gdrive_download_file(file_id, to_file, name, no_progress_bar)
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         if not self.cache["ids"]:
             return
 
+        if prefix:
+            dir_ids = self.cache["dirs"].get(prefix[:2])
+            if not dir_ids:
+                return
+        else:
+            dir_ids = self.cache["ids"]
         parents_query = " or ".join(
-            "'{}' in parents".format(dir_id) for dir_id in self.cache["ids"]
+            "'{}' in parents".format(dir_id) for dir_id in dir_ids
         )
         query = "({}) and trashed=false".format(parents_query)
 
diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py
index 9230699535..7572596191 100644
--- a/dvc/remote/gs.py
+++ b/dvc/remote/gs.py
@@ -69,6 +69,7 @@ class RemoteGS(RemoteBASE):
     scheme = Schemes.GS
     path_cls = CloudURLInfo
     REQUIRES = {"google-cloud-storage": "google.cloud.storage"}
+    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "md5"
 
     def __init__(self, repo, config):
@@ -126,14 +127,18 @@ def remove(self, path_info):
 
         blob.delete()
 
-    def _list_paths(self, path_info, max_items=None):
+    def _list_paths(self, path_info, max_items=None, prefix=None):
+        if prefix:
+            prefix = "/".join([path_info.path, prefix[:2], prefix[2:]])
+        else:
+            prefix = path_info.path
         for blob in self.gs.bucket(path_info.bucket).list_blobs(
-            prefix=path_info.path, max_results=max_items
+            prefix=prefix, max_results=max_items
         ):
             yield blob.name
 
-    def list_cache_paths(self):
-        return self._list_paths(self.path_info)
+    def list_cache_paths(self, prefix=None):
+        return self._list_paths(self.path_info, prefix=prefix)
 
     def walk_files(self, path_info):
         for fname in self._list_paths(path_info / ""):
diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py
index 6c6cd62bab..2b6be5229d 100644
--- a/dvc/remote/hdfs.py
+++ b/dvc/remote/hdfs.py
@@ -21,6 +21,7 @@ class RemoteHDFS(RemoteBASE):
     REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
     PARAM_CHECKSUM = "checksum"
     REQUIRES = {"pyarrow": "pyarrow"}
+    DEFAULT_NO_TRAVERSE = False
 
     def __init__(self, repo, config):
         super().__init__(repo, config)
@@ -152,16 +153,26 @@ def open(self, path_info, mode="r", encoding=None):
                 raise FileNotFoundError(*e.args)
             raise
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         if not self.exists(self.path_info):
             return
 
-        dirs = deque([self.path_info.path])
+        if prefix:
+            root = "/".join([self.path_info.path, prefix[:2]])
+        else:
+            root = self.path_info.path
+        dirs = deque([root])
         with self.hdfs(self.path_info) as hdfs:
             while dirs:
-                for entry in hdfs.ls(dirs.pop(), detail=True):
-                    if entry["kind"] == "directory":
-                        dirs.append(urlparse(entry["name"]).path)
-                    elif entry["kind"] == "file":
-                        yield urlparse(entry["name"]).path
+                try:
+                    for entry in hdfs.ls(dirs.pop(), detail=True):
+                        if entry["kind"] == "directory":
+                            dirs.append(urlparse(entry["name"]).path)
+                        elif entry["kind"] == "file":
+                            yield urlparse(entry["name"]).path
+                except IOError as e:
+                    # When searching for a specific prefix pyarrow raises an
+                    # exception if the specified cache dir does not exist
+                    if not prefix:
+                        raise e
 
diff --git a/dvc/remote/local.py b/dvc/remote/local.py
index d1470620ec..3f0cbc35ff 100644
--- a/dvc/remote/local.py
+++ b/dvc/remote/local.py
@@ -56,7 +56,7 @@ def cache_dir(self, value):
     def supported(cls, config):
         return True
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         assert self.path_info is not None
         return walk_files(self.path_info)
 
diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py
index 989eb829fa..0376809401 100644
--- a/dvc/remote/oss.py
+++ b/dvc/remote/oss.py
@@ -35,6 +35,7 @@ class RemoteOSS(RemoteBASE):
     scheme = Schemes.OSS
     path_cls = CloudURLInfo
     REQUIRES = {"oss2": "oss2"}
+    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
 
@@ -95,8 +96,12 @@ def _list_paths(self, prefix):
         for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
             yield blob.key
 
-    def list_cache_paths(self):
-        return self._list_paths(self.path_info.path)
+    def list_cache_paths(self, prefix=None):
+        if prefix:
+            prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]])
+        else:
+            prefix = self.path_info.path
+        return self._list_paths(prefix)
 
     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py
index 61636a0154..99c4b47994 100644
--- a/dvc/remote/s3.py
+++ b/dvc/remote/s3.py
@@ -21,6 +21,7 @@ class RemoteS3(RemoteBASE):
     scheme = Schemes.S3
     path_cls = CloudURLInfo
     REQUIRES = {"boto3": "boto3"}
+    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "etag"
 
     def __init__(self, repo, config):
@@ -190,24 +191,27 @@ def remove(self, path_info):
         logger.debug("Removing {}".format(path_info))
         self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)
 
-    def _list_objects(self, path_info, max_items=None):
+    def _list_objects(self, path_info, max_items=None, prefix=None):
         """ Read config for list object api, paginate through list objects."""
         kwargs = {
             "Bucket": path_info.bucket,
             "Prefix": path_info.path,
             "PaginationConfig": {"MaxItems": max_items},
         }
+        if prefix:
+            kwargs["Prefix"] = "/".join([path_info.path, prefix[:2]])
         paginator = self.s3.get_paginator(self.list_objects_api)
         for page in paginator.paginate(**kwargs):
             yield from page.get("Contents", ())
 
-    def _list_paths(self, path_info, max_items=None):
+    def _list_paths(self, path_info, max_items=None, prefix=None):
         return (
-            item["Key"] for item in self._list_objects(path_info, max_items)
+            item["Key"]
+            for item in self._list_objects(path_info, max_items, prefix)
         )
 
-    def list_cache_paths(self):
-        return self._list_paths(self.path_info)
+    def list_cache_paths(self, prefix=None):
+        return self._list_paths(self.path_info, prefix=prefix)
 
     def isfile(self, path_info):
         from botocore.exceptions import ClientError
diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py
index 09c23bc6be..35dff92d52 100644
--- a/dvc/remote/ssh/__init__.py
+++ b/dvc/remote/ssh/__init__.py
@@ -249,7 +249,7 @@ def open(self, path_info, mode="r", encoding=None):
         else:
             yield io.TextIOWrapper(fd, encoding=encoding)
 
-    def list_cache_paths(self):
+    def list_cache_paths(self, prefix=None):
         with self.ssh(self.path_info) as ssh:
             # If we simply return an iterator then with above closes instantly
             yield from ssh.walk_files(self.path_info.path)

From 82f1b5281eb715046484e58f28d1c843a9ebe969 Mon Sep 17 00:00:00 2001
From: Peter Rowlands
Date: Mon, 23 Mar 2020 13:35:37 +0900
Subject: [PATCH 02/11] Use arbitrary prefix length for traverse remote cache
 queries

- default to 3
- for remotes that only support per-directory queries, use 2 (gdrive,
  hdfs)
---
 dvc/remote/base.py   | 29 +++++++++++++++++++++--------
 dvc/remote/gdrive.py |  1 +
 dvc/remote/hdfs.py   |  1 +
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 5827cc8c52..4330981e67 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -84,6 +84,7 @@ class RemoteBASE(object):
DEFAULT_VERIFY = False LIST_OBJECT_PAGE_SIZE = 1000 TRAVERSE_WEIGHT_MULTIPLIER = 20 + TRAVERSE_PREFIX_LEN = 3 CACHE_MODE = None SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} @@ -830,13 +831,15 @@ def cache_exists(self, checksums, jobs=None, name=None): # Fetch one parent cache dir for estimating size of entire remote cache checksums = frozenset(checksums) + prefix = "0" * self.TRAVERSE_PREFIX_LEN + traverse_dirs = pow(16, self.TRAVERSE_PREFIX_LEN) remote_checksums = set( - map(self.path_to_checksum, self.list_cache_paths(prefix="00")) + map(self.path_to_checksum, self.list_cache_paths(prefix=prefix)) ) if not remote_checksums: - remote_size = 256 + remote_size = traverse_dirs else: - remote_size = 256 * len(remote_checksums) + remote_size = traverse_dirs * len(remote_checksums) logger.debug("Estimated remote size: {} files".format(remote_size)) traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE @@ -871,16 +874,29 @@ def cache_exists(self, checksums, jobs=None, name=None): ) return list(checksums & set(self.all())) + return self._cache_exists_traverse( + checksums, remote_checksums, jobs, name + ) + + def _cache_exists_traverse( + self, checksums, remote_checksums, jobs=None, name=None + ): logger.debug( "Querying {} checksums via threaded traverse".format( len(checksums) ) ) + traverse_prefixes = ["{:02x}".format(i) for i in range(1, 256)] + if self.TRAVERSE_PREFIX_LEN > 2: + traverse_prefixes += [ + "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN) + for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2)) + ] with Tqdm( desc="Querying " + ("cache in " + name if name else "remote cache"), - total=256, + total=len(traverse_prefixes), unit="dir", ) as pbar: @@ -895,10 +911,7 @@ def list_with_update(prefix): return ret with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - in_remote = executor.map( - list_with_update, - ["{:02x}".format(i) for i in range(1, 256)], - ) + in_remote = executor.map(list_with_update, traverse_prefixes,) remote_checksums.update( itertools.chain.from_iterable(in_remote) ) diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 175d8afdec..f1d43f41e6 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -77,6 +77,7 @@ class RemoteGDrive(RemoteBASE): DEFAULT_VERIFY = True # Always prefer traverse for GDrive since API usage quotas are a concern. 
TRAVERSE_WEIGHT_MULTIPLIER = 1 + TRAVERSE_PREFIX_LEN = 2 GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA" DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json" diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 2b6be5229d..0ef172de17 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -22,6 +22,7 @@ class RemoteHDFS(RemoteBASE): PARAM_CHECKSUM = "checksum" REQUIRES = {"pyarrow": "pyarrow"} DEFAULT_NO_TRAVERSE = False + TRAVERSE_PREFIX_LEN = 2 def __init__(self, repo, config): super().__init__(repo, config) From a5a59469502f2a758918d71a341c07b183f9f58e Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 23 Mar 2020 13:37:07 +0900 Subject: [PATCH 03/11] Add unit tests for new traverse/no_traverse behavior --- tests/unit/remote/test_base.py | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 2199886350..d82bfa66cc 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -35,3 +35,75 @@ def test(self): ): with self.assertRaises(RemoteCmdError): self.REMOTE_CLS(repo, config).remove("file") + + +@mock.patch.object(RemoteBASE, "_cache_exists_traverse") +@mock.patch.object(RemoteBASE, "_cache_exists_no_traverse") +@mock.patch.object( + RemoteBASE, "path_to_checksum", side_effect=lambda x: x, +) +def test_cache_exists(path_to_checksum, no_traverse, traverse): + remote = RemoteBASE(None, {}) + + # no_traverse option set + remote.no_traverse = True + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000)) + remote.cache_exists(checksums) + no_traverse.assert_called_with(checksums, None, None) + traverse.assert_not_called() + + remote.no_traverse = False + + # large remote, small local + no_traverse.reset_mock() + traverse.reset_mock() + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000)) + remote.cache_exists(checksums) + no_traverse.assert_called_with(frozenset(range(256, 1000)), None, None) + traverse.assert_not_called() + + # large remote, large local + no_traverse.reset_mock() + traverse.reset_mock() + remote.JOBS = 16 + with mock.patch.object( + remote, "list_cache_paths", return_value=list(range(256)) + ): + checksums = list(range(1000000)) + remote.cache_exists(checksums) + no_traverse.assert_not_called() + traverse.assert_called_with( + frozenset(checksums), set(range(256)), None, None + ) + + # default traverse + no_traverse.reset_mock() + traverse.reset_mock() + remote.TRAVERSE_WEIGHT_MULTIPLIER = 1 + with mock.patch.object(remote, "list_cache_paths", return_value=[0]): + checksums = set(range(1000000)) + remote.cache_exists(checksums) + traverse.assert_not_called() + no_traverse.assert_not_called() + + +@mock.patch.object( + RemoteBASE, "list_cache_paths", return_value=[], +) +@mock.patch.object( + RemoteBASE, "path_to_checksum", side_effect=lambda x: x, +) +def test_cache_exists_traverse(path_to_checksum, list_cache_paths): + remote = RemoteBASE(None, {}) + remote.path_info = "" + remote._cache_exists_traverse(set([0]), set()) + for i in range(1, 16): + list_cache_paths.assert_any_call(prefix="{:03x}".format(i)) + for i in range(1, 256): + list_cache_paths.assert_any_call(prefix="{:02x}".format(i)) From 5a5bb1893e011bd0d2c73defda2871f47212c4d5 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 23 Mar 2020 14:32:10 +0900 Subject: [PATCH 04/11] Fix deepsource warning --- 
tests/unit/remote/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index d82bfa66cc..b19dce0de5 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -102,7 +102,7 @@ def test_cache_exists(path_to_checksum, no_traverse, traverse): def test_cache_exists_traverse(path_to_checksum, list_cache_paths): remote = RemoteBASE(None, {}) remote.path_info = "" - remote._cache_exists_traverse(set([0]), set()) + remote._cache_exists_traverse({0}, set()) for i in range(1, 16): list_cache_paths.assert_any_call(prefix="{:03x}".format(i)) for i in range(1, 256): From f17a0ade0fa4ec10cc6b0e58459b0ac1955791ef Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 24 Mar 2020 12:15:42 +0900 Subject: [PATCH 05/11] Fix review issues --- dvc/remote/azure.py | 5 ++++- dvc/remote/base.py | 3 ++- dvc/remote/gs.py | 3 ++- dvc/remote/hdfs.py | 2 +- dvc/remote/local.py | 1 + dvc/remote/oss.py | 5 ++++- dvc/remote/s3.py | 3 ++- dvc/remote/ssh/__init__.py | 1 + 8 files changed, 17 insertions(+), 6 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 280d85b56a..b2c97dd13b 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -1,5 +1,6 @@ import logging import os +import posixpath import re from datetime import datetime, timedelta from urllib.parse import urlparse @@ -106,7 +107,9 @@ def _list_paths(self, bucket, prefix): def list_cache_paths(self, prefix=None): if prefix: - prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]]) + prefix = posixpath.join( + [self.path_info.path, prefix[:2], prefix[2:]] + ) else: prefix = self.path_info.path return self._list_paths(self.path_info.bucket, prefix) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 4330981e67..df4a881f66 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -85,6 +85,7 @@ class RemoteBASE(object): LIST_OBJECT_PAGE_SIZE = 1000 TRAVERSE_WEIGHT_MULTIPLIER = 20 TRAVERSE_PREFIX_LEN = 3 + TRAVERSE_THRESHOLD_SIZE = 500000 CACHE_MODE = None SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} @@ -847,7 +848,7 @@ def cache_exists(self, checksums, jobs=None, name=None): # for performance overhead from large lists/sets. 
# From testing with S3, for remotes with 1M+ files, no_traverse is # faster until len(checksums) is at least 10k~100k - if remote_size > 500000: + if remote_size > self.TRAVERSE_THRESHOLD_SIZE: traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER else: traverse_weight = traverse_pages diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 7572596191..1e2053baeb 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -3,6 +3,7 @@ from functools import wraps import io import os.path +import posixpath import threading from funcy import cached_property, wrap_prop @@ -129,7 +130,7 @@ def remove(self, path_info): def _list_paths(self, path_info, max_items=None, prefix=None): if prefix: - prefix = "/".join([path_info.path, prefix[:2], prefix[2:]]) + prefix = posixpath.join([path_info.path, prefix[:2], prefix[2:]]) else: prefix = path_info.path for blob in self.gs.bucket(path_info.bucket).list_blobs( diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 0ef172de17..489f3795a8 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -159,7 +159,7 @@ def list_cache_paths(self, prefix=None): return if prefix: - root = "/".join([self.path_info.path, prefix[:2]]) + root = posixpath.join([self.path_info.path, prefix[:2]]) else: root = self.path_info.path dirs = deque([root]) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 3f0cbc35ff..25d409f67d 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -57,6 +57,7 @@ def supported(cls, config): return True def list_cache_paths(self, prefix=None): + assert prefix is None assert self.path_info is not None return walk_files(self.path_info) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 0376809401..48a2155aed 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -1,5 +1,6 @@ import logging import os +import posixpath import threading from funcy import cached_property, wrap_prop @@ -98,7 +99,9 @@ def _list_paths(self, prefix): def list_cache_paths(self, prefix=None): if prefix: - prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]]) + prefix = posixpath.join( + [self.path_info.path, prefix[:2], prefix[2:]] + ) else: prefix = self.path_info.path return self._list_paths(prefix) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 99c4b47994..b4463dbabb 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -2,6 +2,7 @@ import logging import os +import posixpath import threading from funcy import cached_property, wrap_prop @@ -199,7 +200,7 @@ def _list_objects(self, path_info, max_items=None, prefix=None): "PaginationConfig": {"MaxItems": max_items}, } if prefix: - kwargs["Prefix"] = "/".join([path_info.path, prefix[:2]]) + kwargs["Prefix"] = posixpath.join([path_info.path, prefix[:2]]) paginator = self.s3.get_paginator(self.list_objects_api) for page in paginator.paginate(**kwargs): yield from page.get("Contents", ()) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 35dff92d52..f34e7adf84 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -250,6 +250,7 @@ def open(self, path_info, mode="r", encoding=None): yield io.TextIOWrapper(fd, encoding=encoding) def list_cache_paths(self, prefix=None): + assert prefix is None with self.ssh(self.path_info) as ssh: # If we simply return an iterator then with above closes instantly yield from ssh.walk_files(self.path_info.path) From 3bbe45b5ec5e491d766589a83ca5d4e5c4a7519d Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 24 Mar 2020 13:46:12 +0900 Subject: [PATCH 06/11] fix CI --- dvc/remote/azure.py 
| 2 +- dvc/remote/gs.py | 2 +- dvc/remote/hdfs.py | 2 +- dvc/remote/oss.py | 2 +- dvc/remote/s3.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index b2c97dd13b..f928da1d8e 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -108,7 +108,7 @@ def _list_paths(self, bucket, prefix): def list_cache_paths(self, prefix=None): if prefix: prefix = posixpath.join( - [self.path_info.path, prefix[:2], prefix[2:]] + self.path_info.path, prefix[:2], prefix[2:] ) else: prefix = self.path_info.path diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 1e2053baeb..31ce7e1fd2 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -130,7 +130,7 @@ def remove(self, path_info): def _list_paths(self, path_info, max_items=None, prefix=None): if prefix: - prefix = posixpath.join([path_info.path, prefix[:2], prefix[2:]]) + prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:]) else: prefix = path_info.path for blob in self.gs.bucket(path_info.bucket).list_blobs( diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 489f3795a8..81cb4376f1 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -159,7 +159,7 @@ def list_cache_paths(self, prefix=None): return if prefix: - root = posixpath.join([self.path_info.path, prefix[:2]]) + root = posixpath.join(self.path_info.path, prefix[:2]) else: root = self.path_info.path dirs = deque([root]) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 48a2155aed..a2eff6d715 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -100,7 +100,7 @@ def _list_paths(self, prefix): def list_cache_paths(self, prefix=None): if prefix: prefix = posixpath.join( - [self.path_info.path, prefix[:2], prefix[2:]] + self.path_info.path, prefix[:2], prefix[2:] ) else: prefix = self.path_info.path diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index b4463dbabb..c265ae3661 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -200,7 +200,7 @@ def _list_objects(self, path_info, max_items=None, prefix=None): "PaginationConfig": {"MaxItems": max_items}, } if prefix: - kwargs["Prefix"] = posixpath.join([path_info.path, prefix[:2]]) + kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2]) paginator = self.s3.get_paginator(self.list_objects_api) for page in paginator.paginate(**kwargs): yield from page.get("Contents", ()) From 65d7e2275e5e3eecf0e44f508071466ed545977a Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 24 Mar 2020 14:17:28 +0900 Subject: [PATCH 07/11] Fix review issues --- dvc/remote/base.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index df4a881f66..ce477564f2 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -823,24 +823,36 @@ def cache_exists(self, checksums, jobs=None, name=None): check if particular file exists much quicker, use their own implementation of cache_exists (see ssh, local). + Which method to use will be automatically determined after estimating + the size of the remote cache, and comparing the estimated size with + len(checksums). To estimate the size of the remote cache, we fetch + a small subset of cache entries (i.e. entries starting with "00..."). + Based on the number of entries in that subset, the size of the full + cache can be estimated, since the cache is evenly distributed according + to checksum. 
+ Returns: A list with checksums that were found in the remote """ + # Remotes which do not use traverse prefix should override + # cache_exists() (see ssh, local) + assert self.TRAVERSE_PREFIX_LEN >= 2 if self.no_traverse: return self._cache_exists_no_traverse(checksums, jobs, name) - # Fetch one parent cache dir for estimating size of entire remote cache + # Fetch cache entries beginning with "00..." prefix for estimating the + # size of entire remote cache checksums = frozenset(checksums) prefix = "0" * self.TRAVERSE_PREFIX_LEN - traverse_dirs = pow(16, self.TRAVERSE_PREFIX_LEN) + total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) remote_checksums = set( map(self.path_to_checksum, self.list_cache_paths(prefix=prefix)) ) - if not remote_checksums: - remote_size = traverse_dirs + if remote_checksums: + remote_size = total_prefixes * len(remote_checksums) else: - remote_size = traverse_dirs * len(remote_checksums) + remote_size = total_prefixes logger.debug("Estimated remote size: {} files".format(remote_size)) traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE @@ -854,7 +866,10 @@ def cache_exists(self, checksums, jobs=None, name=None): traverse_weight = traverse_pages if len(checksums) < traverse_weight: logger.debug( - "Large remote, using no_traverse for remaining checksums" + "Large remote ({} checksums < {} traverse weight), " + "using no_traverse for remaining checksums".format( + len(checksums), traverse_weight + ) ) return list( checksums & remote_checksums From 88ea4dd54ba50f23f2443bfcbbfa61b19ebf463c Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 24 Mar 2020 22:29:47 +0900 Subject: [PATCH 08/11] obsolete no_traverse config option - remove RemoteBASE.DEFAULT_NO_TRAVERSE and replace it with RemoteBASE.CAN_TRAVERSE (True for all remotes except http) --- dvc/config.py | 2 +- dvc/remote/azure.py | 1 - dvc/remote/base.py | 5 ++--- dvc/remote/gdrive.py | 1 - dvc/remote/gs.py | 1 - dvc/remote/hdfs.py | 1 - dvc/remote/http.py | 8 +------- dvc/remote/oss.py | 1 - dvc/remote/s3.py | 1 - dvc/remote/ssh/__init__.py | 3 --- tests/unit/remote/test_base.py | 6 +++--- tests/unit/remote/test_gdrive.py | 1 - tests/unit/remote/test_http.py | 12 ------------ tests/unit/test_config.py | 16 ---------------- 14 files changed, 7 insertions(+), 52 deletions(-) delete mode 100644 tests/unit/test_config.py diff --git a/dvc/config.py b/dvc/config.py index 34addc400b..72a4513110 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -92,7 +92,7 @@ class RelPath(str): REMOTE_COMMON = { "url": str, "checksum_jobs": All(Coerce(int), Range(1)), - "no_traverse": Bool, + Optional("no_traverse", default=False): Bool, # obsoleted "verify": Bool, } LOCAL_COMMON = { diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index f928da1d8e..23770ad18e 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -27,7 +27,6 @@ class RemoteAZURE(RemoteBASE): r"(?P.+)?)?)$" ) REQUIRES = {"azure-storage-blob": "azure.storage.blob"} - DEFAULT_NO_TRAVERSE = False PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 diff --git a/dvc/remote/base.py b/dvc/remote/base.py index ce477564f2..f62240e9fd 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -80,12 +80,12 @@ class RemoteBASE(object): CHECKSUM_DIR_SUFFIX = ".dir" CHECKSUM_JOBS = max(1, min(4, cpu_count() // 2)) DEFAULT_CACHE_TYPES = ["copy"] - DEFAULT_NO_TRAVERSE = True DEFAULT_VERIFY = False LIST_OBJECT_PAGE_SIZE = 1000 TRAVERSE_WEIGHT_MULTIPLIER = 20 TRAVERSE_PREFIX_LEN = 3 TRAVERSE_THRESHOLD_SIZE = 500000 + CAN_TRAVERSE = True CACHE_MODE = None 
     SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
 
@@ -105,7 +105,6 @@ def __init__(self, repo, config):
             or (self.repo and self.repo.config["core"].get("checksum_jobs"))
             or self.CHECKSUM_JOBS
         )
-        self.no_traverse = config.get("no_traverse", self.DEFAULT_NO_TRAVERSE)
         self.verify = config.get("verify", self.DEFAULT_VERIFY)
 
         self._dir_info = {}
@@ -838,7 +837,7 @@ def cache_exists(self, checksums, jobs=None, name=None):
         # cache_exists() (see ssh, local)
         assert self.TRAVERSE_PREFIX_LEN >= 2
 
-        if self.no_traverse:
+        if len(checksums) == 1 or not self.CAN_TRAVERSE:
             return self._cache_exists_no_traverse(checksums, jobs, name)
 
         # Fetch cache entries beginning with "00..." prefix for estimating the
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py
index f1d43f41e6..484bd1a827 100644
--- a/dvc/remote/gdrive.py
+++ b/dvc/remote/gdrive.py
@@ -73,7 +73,6 @@ class RemoteGDrive(RemoteBASE):
     scheme = Schemes.GDRIVE
     path_cls = GDriveURLInfo
     REQUIRES = {"pydrive2": "pydrive2"}
-    DEFAULT_NO_TRAVERSE = False
     DEFAULT_VERIFY = True
     # Always prefer traverse for GDrive since API usage quotas are a concern.
     TRAVERSE_WEIGHT_MULTIPLIER = 1
diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py
index 31ce7e1fd2..eab424458f 100644
--- a/dvc/remote/gs.py
+++ b/dvc/remote/gs.py
@@ -70,7 +70,6 @@ class RemoteGS(RemoteBASE):
     scheme = Schemes.GS
     path_cls = CloudURLInfo
     REQUIRES = {"google-cloud-storage": "google.cloud.storage"}
-    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "md5"
 
     def __init__(self, repo, config):
diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py
index 81cb4376f1..f7e3439740 100644
--- a/dvc/remote/hdfs.py
+++ b/dvc/remote/hdfs.py
@@ -21,7 +21,6 @@ class RemoteHDFS(RemoteBASE):
     REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
     PARAM_CHECKSUM = "checksum"
     REQUIRES = {"pyarrow": "pyarrow"}
-    DEFAULT_NO_TRAVERSE = False
     TRAVERSE_PREFIX_LEN = 2
 
     def __init__(self, repo, config):
diff --git a/dvc/remote/http.py b/dvc/remote/http.py
index fa4966685e..6ea016ca6d 100644
--- a/dvc/remote/http.py
+++ b/dvc/remote/http.py
@@ -6,7 +6,6 @@
 
 from dvc.path_info import HTTPURLInfo
 import dvc.prompt as prompt
-from dvc.config import ConfigError
 from dvc.exceptions import DvcException, HTTPError
 from dvc.progress import Tqdm
 from dvc.remote.base import RemoteBASE
@@ -32,6 +31,7 @@ class RemoteHTTP(RemoteBASE):
     REQUEST_TIMEOUT = 10
     CHUNK_SIZE = 2 ** 16
     PARAM_CHECKSUM = "etag"
+    CAN_TRAVERSE = False
 
     def __init__(self, repo, config):
         super().__init__(repo, config)
@@ -45,12 +45,6 @@ def __init__(self, repo, config):
         else:
             self.path_info = None
 
-        if not self.no_traverse:
-            raise ConfigError(
-                "HTTP doesn't support traversing the remote to list existing "
-                "files. Use: `dvc remote modify <name> no_traverse true`"
-            )
-
         self.auth = config.get("auth", None)
         self.custom_auth_header = config.get("custom_auth_header", None)
         self.password = config.get("password", None)
diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py
index a2eff6d715..9d928f7cb0 100644
--- a/dvc/remote/oss.py
+++ b/dvc/remote/oss.py
@@ -36,7 +36,6 @@ class RemoteOSS(RemoteBASE):
     scheme = Schemes.OSS
     path_cls = CloudURLInfo
     REQUIRES = {"oss2": "oss2"}
-    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
 
diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py
index c265ae3661..7bc9b0fcc1 100644
--- a/dvc/remote/s3.py
+++ b/dvc/remote/s3.py
@@ -22,7 +22,6 @@ class RemoteS3(RemoteBASE):
     scheme = Schemes.S3
     path_cls = CloudURLInfo
     REQUIRES = {"boto3": "boto3"}
-    DEFAULT_NO_TRAVERSE = False
     PARAM_CHECKSUM = "etag"
 
     def __init__(self, repo, config):
diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py
index f34e7adf84..a981864680 100644
--- a/dvc/remote/ssh/__init__.py
+++ b/dvc/remote/ssh/__init__.py
@@ -298,9 +298,6 @@ def cache_exists(self, checksums, jobs=None, name=None):
         faster than current approach (relying on exists(path_info))
         applied in remote/base.
         """
-        if not self.no_traverse:
-            return list(set(checksums) & set(self.all()))
-
         # possibly prompt for credentials before "Querying" progress output
         self.ensure_credentials()
 
diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py
index b19dce0de5..a9a6595097 100644
--- a/tests/unit/remote/test_base.py
+++ b/tests/unit/remote/test_base.py
@@ -45,8 +45,8 @@ def test(self):
 def test_cache_exists(path_to_checksum, no_traverse, traverse):
     remote = RemoteBASE(None, {})
 
-    # no_traverse option set
-    remote.no_traverse = True
+    # remote does not support traverse
+    remote.CAN_TRAVERSE = False
     with mock.patch.object(
         remote, "list_cache_paths", return_value=list(range(256))
     ):
@@ -55,7 +55,7 @@ def test_cache_exists(path_to_checksum, no_traverse, traverse):
         no_traverse.assert_called_with(checksums, None, None)
         traverse.assert_not_called()
 
-    remote.no_traverse = False
+    remote.CAN_TRAVERSE = True
 
     # large remote, small local
diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py
index e165903ccb..72b790ea92 100644
--- a/tests/unit/remote/test_gdrive.py
+++ b/tests/unit/remote/test_gdrive.py
@@ -28,7 +28,6 @@ class TestRemoteGDrive(object):
     def test_init(self):
         remote = RemoteGDrive(Repo(), self.CONFIG)
         assert str(remote.path_info) == self.CONFIG["url"]
-        assert not remote.no_traverse
 
     def test_drive(self):
         remote = RemoteGDrive(Repo(), self.CONFIG)
diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py
index 65ad2d2e85..f837e06178 100644
--- a/tests/unit/remote/test_http.py
+++ b/tests/unit/remote/test_http.py
@@ -1,23 +1,11 @@
 import pytest
 
-from dvc.config import ConfigError
 from dvc.exceptions import HTTPError
 from dvc.path_info import URLInfo
 from dvc.remote.http import RemoteHTTP
 from tests.utils.httpd import StaticFileServer
 
 
-def test_no_traverse_compatibility(dvc):
-    config = {
-        "url": "http://example.com/",
-        "path_info": "file.html",
-        "no_traverse": False,
-    }
-
-    with pytest.raises(ConfigError):
-        RemoteHTTP(dvc, config)
-
-
 def test_download_fails_on_error_code(dvc):
     with StaticFileServer() as httpd:
         url = "http://localhost:{}/".format(httpd.server_port)
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
deleted file mode 100644
index d71d3c4077..0000000000
--- a/tests/unit/test_config.py
+++ /dev/null
@@ -1,16 +0,0 @@
-1,16 +0,0 @@ -from dvc.config import COMPILED_SCHEMA - - -def test_remote_config_no_traverse(): - d = COMPILED_SCHEMA({"remote": {"myremote": {"url": "url"}}}) - assert "no_traverse" not in d["remote"]["myremote"] - - d = COMPILED_SCHEMA( - {"remote": {"myremote": {"url": "url", "no_traverse": "fAlSe"}}} - ) - assert not d["remote"]["myremote"]["no_traverse"] - - d = COMPILED_SCHEMA( - {"remote": {"myremote": {"url": "url", "no_traverse": "tRuE"}}} - ) - assert d["remote"]["myremote"]["no_traverse"] From 60938689e1cd17696a1938ff9a1681077f57f2dc Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 24 Mar 2020 22:55:02 +0900 Subject: [PATCH 09/11] review issues - use 'object exists' instead of 'no_traverse' --- dvc/remote/base.py | 26 +++++++++++++------------- tests/unit/remote/test_base.py | 23 +++++++++++++---------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index f62240e9fd..2c41e32279 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -805,13 +805,13 @@ def cache_exists(self, checksums, jobs=None, name=None): There are two ways of performing this check: - - Traverse: Get a list of all the files in the remote + - Traverse method: Get a list of all the files in the remote (traversing the cache directory) and compare it with - the given checksums. Cache parent directories will - be retrieved in parallel threads, and a progress bar - will be displayed. + the given checksums. Cache entries will be retrieved in parallel + threads according to prefix (i.e. entries starting with, "00...", + "01...", and so on) and a progress bar will be displayed. - - No traverse: For each given checksum, run the `exists` + - Exists method: For each given checksum, run the `exists` method and filter the checksums that aren't on the remote. This is done in parallel threads. It also shows a progress bar when performing the check. @@ -838,7 +838,7 @@ def cache_exists(self, checksums, jobs=None, name=None): assert self.TRAVERSE_PREFIX_LEN >= 2 if len(checksums) == 1 or not self.CAN_TRAVERSE: - return self._cache_exists_no_traverse(checksums, jobs, name) + return self._cache_object_exists(checksums, jobs, name) # Fetch cache entries beginning with "00..." prefix for estimating the # size of entire remote cache @@ -857,7 +857,7 @@ def cache_exists(self, checksums, jobs=None, name=None): traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE # For sufficiently large remotes, traverse must be weighted to account # for performance overhead from large lists/sets. 
- # From testing with S3, for remotes with 1M+ files, no_traverse is + # From testing with S3, for remotes with 1M+ files, object_exists is # faster until len(checksums) is at least 10k~100k if remote_size > self.TRAVERSE_THRESHOLD_SIZE: traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER @@ -865,14 +865,14 @@ def cache_exists(self, checksums, jobs=None, name=None): traverse_weight = traverse_pages if len(checksums) < traverse_weight: logger.debug( - "Large remote ({} checksums < {} traverse weight), " - "using no_traverse for remaining checksums".format( + "Large remote ('{}' checksums < '{}' traverse weight), " + "using object_exists for remaining checksums".format( len(checksums), traverse_weight ) ) return list( checksums & remote_checksums - ) + self._cache_exists_no_traverse( + ) + self._cache_object_exists( checksums - remote_checksums, jobs, name ) @@ -921,7 +921,7 @@ def list_with_update(prefix): list(self.list_cache_paths(prefix=prefix)), ) pbar.update_desc( - "Querying cache in {}/{}".format(self.path_info, prefix) + "Querying cache in '{}'".format(self.path_info / prefix) ) return ret @@ -932,9 +932,9 @@ def list_with_update(prefix): ) return list(checksums & remote_checksums) - def _cache_exists_no_traverse(self, checksums, jobs=None, name=None): + def _cache_object_exists(self, checksums, jobs=None, name=None): logger.debug( - "Querying {} checksums via no_traverse".format(len(checksums)) + "Querying {} checksums via object_exists".format(len(checksums)) ) with Tqdm( desc="Querying " diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index a9a6595097..54ec121a9a 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -2,6 +2,7 @@ import mock +from dvc.path_info import PathInfo from dvc.remote.base import RemoteBASE from dvc.remote.base import RemoteCmdError from dvc.remote.base import RemoteMissingDepsError @@ -38,11 +39,11 @@ def test(self): @mock.patch.object(RemoteBASE, "_cache_exists_traverse") -@mock.patch.object(RemoteBASE, "_cache_exists_no_traverse") +@mock.patch.object(RemoteBASE, "_cache_object_exists") @mock.patch.object( RemoteBASE, "path_to_checksum", side_effect=lambda x: x, ) -def test_cache_exists(path_to_checksum, no_traverse, traverse): +def test_cache_exists(path_to_checksum, object_exists, traverse): remote = RemoteBASE(None, {}) # remote does not support traverse @@ -52,24 +53,26 @@ def test_cache_exists(path_to_checksum, no_traverse, traverse): ): checksums = list(range(1000)) remote.cache_exists(checksums) - no_traverse.assert_called_with(checksums, None, None) + object_exists.assert_called_with(checksums, None, None) traverse.assert_not_called() remote.CAN_TRAVERSE = True # large remote, small local - no_traverse.reset_mock() + object_exists.reset_mock() traverse.reset_mock() with mock.patch.object( remote, "list_cache_paths", return_value=list(range(256)) ): checksums = list(range(1000)) remote.cache_exists(checksums) - no_traverse.assert_called_with(frozenset(range(256, 1000)), None, None) + object_exists.assert_called_with( + frozenset(range(256, 1000)), None, None + ) traverse.assert_not_called() # large remote, large local - no_traverse.reset_mock() + object_exists.reset_mock() traverse.reset_mock() remote.JOBS = 16 with mock.patch.object( @@ -77,20 +80,20 @@ def test_cache_exists(path_to_checksum, no_traverse, traverse): ): checksums = list(range(1000000)) remote.cache_exists(checksums) - no_traverse.assert_not_called() + object_exists.assert_not_called() 
traverse.assert_called_with( frozenset(checksums), set(range(256)), None, None ) # default traverse - no_traverse.reset_mock() + object_exists.reset_mock() traverse.reset_mock() remote.TRAVERSE_WEIGHT_MULTIPLIER = 1 with mock.patch.object(remote, "list_cache_paths", return_value=[0]): checksums = set(range(1000000)) remote.cache_exists(checksums) traverse.assert_not_called() - no_traverse.assert_not_called() + object_exists.assert_not_called() @mock.patch.object( @@ -101,7 +104,7 @@ def test_cache_exists(path_to_checksum, no_traverse, traverse): ) def test_cache_exists_traverse(path_to_checksum, list_cache_paths): remote = RemoteBASE(None, {}) - remote.path_info = "" + remote.path_info = PathInfo("foo") remote._cache_exists_traverse({0}, set()) for i in range(1, 16): list_cache_paths.assert_any_call(prefix="{:03x}".format(i)) From 1bb8ef91b33f7b33a3efc9aa7df1e64b016f556e Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 25 Mar 2020 00:12:01 +0900 Subject: [PATCH 10/11] fix obsolete no_traverse test issues --- dvc/config.py | 2 +- tests/func/test_data_cloud.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/dvc/config.py b/dvc/config.py index 72a4513110..50cd8feb3d 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -92,7 +92,7 @@ class RelPath(str): REMOTE_COMMON = { "url": str, "checksum_jobs": All(Coerce(int), Range(1)), - Optional("no_traverse", default=False): Bool, # obsoleted + Optional("no_traverse"): Bool, # obsoleted "verify": Bool, } LOCAL_COMMON = { diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 590094d8a8..fd3bdc2ed9 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -268,7 +268,6 @@ def _setup_cloud(self): config["remote"][TEST_REMOTE] = { "url": repo, "keyfile": keyfile, - "no_traverse": False, } self.dvc.config = config self.cloud = DataCloud(self.dvc) From 75c124fab85b2aa1c1d038e2375280fa841dfd9a Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 25 Mar 2020 10:53:54 +0900 Subject: [PATCH 11/11] fix SSHMocked func tests --- dvc/remote/ssh/__init__.py | 3 +++ tests/func/test_data_cloud.py | 1 + 2 files changed, 4 insertions(+) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index a981864680..1f191fff65 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -298,6 +298,9 @@ def cache_exists(self, checksums, jobs=None, name=None): faster than current approach (relying on exists(path_info)) applied in remote/base. """ + if not self.CAN_TRAVERSE: + return list(set(checksums) & set(self.all())) + # possibly prompt for credentials before "Querying" progress output self.ensure_credentials() diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index fd3bdc2ed9..976de78417 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -264,6 +264,7 @@ def _setup_cloud(self): repo = self.get_url() keyfile = self._get_keyfile() + self._get_cloud_class().CAN_TRAVERSE = False config = copy.deepcopy(TEST_CONFIG) config["remote"][TEST_REMOTE] = { "url": repo,