diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py
index 23770ad18e..f126aaeb43 100644
--- a/dvc/remote/azure.py
+++ b/dvc/remote/azure.py
@@ -29,6 +29,7 @@ class RemoteAZURE(RemoteBASE):
     REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
+    LIST_OBJECT_PAGE_SIZE = 5000
 
     def __init__(self, repo, config):
         super().__init__(repo, config)
@@ -88,7 +89,7 @@ def remove(self, path_info):
         logger.debug("Removing {}".format(path_info))
         self.blob_service.delete_blob(path_info.bucket, path_info.path)
 
-    def _list_paths(self, bucket, prefix):
+    def _list_paths(self, bucket, prefix, progress_callback=None):
         blob_service = self.blob_service
         next_marker = None
         while True:
@@ -97,6 +98,8 @@
             )
 
             for blob in blobs:
+                if progress_callback:
+                    progress_callback()
                 yield blob.name
 
             if not blobs.next_marker:
@@ -104,14 +107,16 @@
 
             next_marker = blobs.next_marker
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         if prefix:
             prefix = posixpath.join(
                 self.path_info.path, prefix[:2], prefix[2:]
             )
         else:
             prefix = self.path_info.path
-        return self._list_paths(self.path_info.bucket, prefix)
+        return self._list_paths(
+            self.path_info.bucket, prefix, progress_callback
+        )
 
     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 644e64c1e6..fc161d0151 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -694,7 +694,7 @@ def path_to_checksum(self, path):
     def checksum_to_path_info(self, checksum):
         return self.path_info / checksum[0:2] / checksum[2:]
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         raise NotImplementedError
 
     def all(self):
@@ -850,9 +850,20 @@ def cache_exists(self, checksums, jobs=None, name=None):
         checksums = frozenset(checksums)
         prefix = "0" * self.TRAVERSE_PREFIX_LEN
         total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN)
-        remote_checksums = set(
-            map(self.path_to_checksum, self.list_cache_paths(prefix=prefix))
-        )
+        with Tqdm(
+            desc="Estimating size of "
+            + ("cache in '{}'".format(name) if name else "remote cache"),
+            unit="file",
+        ) as pbar:
+
+            def update(n=1):
+                pbar.update(n * total_prefixes)
+
+            paths = self.list_cache_paths(
+                prefix=prefix, progress_callback=update
+            )
+            remote_checksums = set(map(self.path_to_checksum, paths))
+
         if remote_checksums:
             remote_size = total_prefixes * len(remote_checksums)
         else:
@@ -895,11 +906,11 @@
             return list(checksums & set(self.all()))
 
         return self._cache_exists_traverse(
-            checksums, remote_checksums, jobs, name
+            checksums, remote_checksums, remote_size, jobs, name
         )
 
     def _cache_exists_traverse(
-        self, checksums, remote_checksums, jobs=None, name=None
+        self, checksums, remote_checksums, remote_size, jobs=None, name=None
     ):
         logger.debug(
             "Querying {} checksums via threaded traverse".format(
@@ -915,20 +926,17 @@
         ]
         with Tqdm(
             desc="Querying "
-            + ("cache in " + name if name else "remote cache"),
-            total=len(traverse_prefixes),
-            unit="dir",
+            + ("cache in '{}'".format(name) if name else "remote cache"),
+            total=remote_size,
+            initial=len(remote_checksums),
+            unit="objects",
         ) as pbar:
 
             def list_with_update(prefix):
-                ret = map(
-                    self.path_to_checksum,
-                    list(self.list_cache_paths(prefix=prefix)),
+                paths = self.list_cache_paths(
+                    prefix=prefix, progress_callback=pbar.update
                 )
-                pbar.update_desc(
-                    "Querying cache in '{}'".format(self.path_info / prefix)
-                )
-                return ret
+                return map(self.path_to_checksum, list(paths))
 
         with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
             in_remote = executor.map(list_with_update, traverse_prefixes,)
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py
index c3d30e3b41..de4211641e 100644
--- a/dvc/remote/gdrive.py
+++ b/dvc/remote/gdrive.py
@@ -463,7 +463,7 @@ def _download(self, from_info, to_file, name, no_progress_bar):
         file_id = self._get_remote_id(from_info)
         self.gdrive_download_file(file_id, to_file, name, no_progress_bar)
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         if not self.cache["ids"]:
             return
 
@@ -479,6 +479,8 @@
         query = "({}) and trashed=false".format(parents_query)
 
         for item in self.gdrive_list_item(query):
+            if progress_callback:
+                progress_callback()
             parent_id = item["parents"][0]["id"]
             yield posixpath.join(self.cache["ids"][parent_id], item["title"])
 
diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py
index 1eb7dbf8b9..78517cad96 100644
--- a/dvc/remote/gs.py
+++ b/dvc/remote/gs.py
@@ -127,7 +127,9 @@ def remove(self, path_info):
 
         blob.delete()
 
-    def _list_paths(self, path_info, max_items=None, prefix=None):
+    def _list_paths(
+        self, path_info, max_items=None, prefix=None, progress_callback=None
+    ):
         if prefix:
             prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:])
         else:
@@ -135,10 +137,14 @@
         for blob in self.gs.bucket(path_info.bucket).list_blobs(
             prefix=path_info.path, max_results=max_items
         ):
+            if progress_callback:
+                progress_callback()
             yield blob.name
 
-    def list_cache_paths(self, prefix=None):
-        return self._list_paths(self.path_info, prefix=prefix)
+    def list_cache_paths(self, prefix=None, progress_callback=None):
+        return self._list_paths(
+            self.path_info, prefix=prefix, progress_callback=progress_callback
+        )
 
     def walk_files(self, path_info):
         for fname in self._list_paths(path_info / ""):
diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py
index f7e3439740..c874b20682 100644
--- a/dvc/remote/hdfs.py
+++ b/dvc/remote/hdfs.py
@@ -153,7 +153,7 @@ def open(self, path_info, mode="r", encoding=None):
                 raise FileNotFoundError(*e.args)
             raise
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         if not self.exists(self.path_info):
             return
 
@@ -166,10 +166,13 @@
         with self.hdfs(self.path_info) as hdfs:
             while dirs:
                 try:
-                    for entry in hdfs.ls(dirs.pop(), detail=True):
+                    entries = hdfs.ls(dirs.pop(), detail=True)
+                    for entry in entries:
                         if entry["kind"] == "directory":
                             dirs.append(urlparse(entry["name"]).path)
                         elif entry["kind"] == "file":
+                            if progress_callback:
+                                progress_callback()
                             yield urlparse(entry["name"]).path
                 except IOError as e:
                     # When searching for a specific prefix pyarrow raises an
diff --git a/dvc/remote/local.py b/dvc/remote/local.py
index 7a93434f3d..708480a135 100644
--- a/dvc/remote/local.py
+++ b/dvc/remote/local.py
@@ -56,10 +56,15 @@ def cache_dir(self, value):
     def supported(cls, config):
         return True
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         assert prefix is None
         assert self.path_info is not None
-        return walk_files(self.path_info)
+        if progress_callback:
+            for path in walk_files(self.path_info):
+                progress_callback()
+                yield path
+        else:
+            yield from walk_files(self.path_info)
 
     def get(self, md5):
         if not md5:
diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py
index 9d928f7cb0..9642fb511b 100644
--- a/dvc/remote/oss.py
+++ b/dvc/remote/oss.py
@@ -38,6 +38,7 @@ class RemoteOSS(RemoteBASE):
     REQUIRES = {"oss2": "oss2"}
     PARAM_CHECKSUM = "etag"
     COPY_POLL_SECONDS = 5
+    LIST_OBJECT_PAGE_SIZE = 100
 
     def __init__(self, repo, config):
         super().__init__(repo, config)
@@ -90,20 +91,22 @@
         logger.debug("Removing oss://{}".format(path_info))
         self.oss_service.delete_object(path_info.path)
 
-    def _list_paths(self, prefix):
+    def _list_paths(self, prefix, progress_callback=None):
         import oss2
 
         for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
+            if progress_callback:
+                progress_callback()
             yield blob.key
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         if prefix:
             prefix = posixpath.join(
                 self.path_info.path, prefix[:2], prefix[2:]
             )
         else:
             prefix = self.path_info.path
-        return self._list_paths(prefix)
+        return self._list_paths(prefix, progress_callback)
 
     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py
index 4eef5fc7b3..d75ca42546 100644
--- a/dvc/remote/s3.py
+++ b/dvc/remote/s3.py
@@ -191,7 +191,9 @@ def remove(self, path_info):
         logger.debug("Removing {}".format(path_info))
         self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)
 
-    def _list_objects(self, path_info, max_items=None, prefix=None):
+    def _list_objects(
+        self, path_info, max_items=None, prefix=None, progress_callback=None
+    ):
         """ Read config for list object api, paginate through list objects."""
         kwargs = {
             "Bucket": path_info.bucket,
@@ -202,16 +204,28 @@
             kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2])
         paginator = self.s3.get_paginator(self.list_objects_api)
         for page in paginator.paginate(**kwargs):
-            yield from page.get("Contents", ())
-
-    def _list_paths(self, path_info, max_items=None, prefix=None):
+            contents = page.get("Contents", ())
+            if progress_callback:
+                for item in contents:
+                    progress_callback()
+                    yield item
+            else:
+                yield from contents
+
+    def _list_paths(
+        self, path_info, max_items=None, prefix=None, progress_callback=None
+    ):
         return (
             item["Key"]
-            for item in self._list_objects(path_info, max_items, prefix)
+            for item in self._list_objects(
+                path_info, max_items, prefix, progress_callback
+            )
         )
 
-    def list_cache_paths(self, prefix=None):
-        return self._list_paths(self.path_info, prefix=prefix)
+    def list_cache_paths(self, prefix=None, progress_callback=None):
+        return self._list_paths(
+            self.path_info, prefix=prefix, progress_callback=progress_callback
+        )
 
     def isfile(self, path_info):
         from botocore.exceptions import ClientError
diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py
index 1f191fff65..703617b2af 100644
--- a/dvc/remote/ssh/__init__.py
+++ b/dvc/remote/ssh/__init__.py
@@ -249,11 +249,16 @@ def open(self, path_info, mode="r", encoding=None):
         else:
             yield io.TextIOWrapper(fd, encoding=encoding)
 
-    def list_cache_paths(self, prefix=None):
+    def list_cache_paths(self, prefix=None, progress_callback=None):
         assert prefix is None
         with self.ssh(self.path_info) as ssh:
             # If we simply return an iterator then with above closes instantly
-            yield from ssh.walk_files(self.path_info.path)
+            if progress_callback:
+                for path in ssh.walk_files(self.path_info.path):
+                    progress_callback()
+                    yield path
+            else:
+                yield from ssh.walk_files(self.path_info.path)
 
     def walk_files(self, path_info):
         with self.ssh(path_info) as ssh:
diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py
index 54ec121a9a..64e64bc676 100644
--- a/tests/unit/remote/test_base.py
+++ b/tests/unit/remote/test_base.py
@@ -8,6 +8,16 @@
 from dvc.remote.base import RemoteMissingDepsError
 
 
+class _CallableOrNone(object):
+    """Helper for testing if object is callable() or None."""
+
+    def __eq__(self, other):
+        return other is None or callable(other)
+
+
+CallableOrNone = _CallableOrNone()
+
+
 class TestRemoteBASE(object):
     REMOTE_CLS = RemoteBASE
 
@@ -82,7 +92,11 @@ def test_cache_exists(path_to_checksum, object_exists, traverse):
     remote.cache_exists(checksums)
     object_exists.assert_not_called()
     traverse.assert_called_with(
-        frozenset(checksums), set(range(256)), None, None
+        frozenset(checksums),
+        set(range(256)),
+        256 * pow(16, remote.TRAVERSE_PREFIX_LEN),
+        None,
+        None,
     )
 
     # default traverse
@@ -105,8 +119,12 @@
 def test_cache_exists_traverse(path_to_checksum, list_cache_paths):
     remote = RemoteBASE(None, {})
     remote.path_info = PathInfo("foo")
-    remote._cache_exists_traverse({0}, set())
+    remote._cache_exists_traverse({0}, set(), 4096)
     for i in range(1, 16):
-        list_cache_paths.assert_any_call(prefix="{:03x}".format(i))
+        list_cache_paths.assert_any_call(
+            prefix="{:03x}".format(i), progress_callback=CallableOrNone
+        )
     for i in range(1, 256):
-        list_cache_paths.assert_any_call(prefix="{:02x}".format(i))
+        list_cache_paths.assert_any_call(
+            prefix="{:02x}".format(i), progress_callback=CallableOrNone
+        )
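Note on the pattern: the patch threads an optional progress_callback from cache_exists() down into each remote's listing generator, so the progress bar owned by the caller ticks once per object listed, and the size estimate scales each tick by the number of checksum prefixes. A minimal standalone sketch of the same plumbing, using stock tqdm in place of DVC's Tqdm wrapper; fake_list_cache_paths and estimate_remote_size are illustrative stand-ins, not part of the patch:

from tqdm import tqdm


def fake_list_cache_paths(prefix=None, progress_callback=None):
    # Stand-in for a remote listing generator: tick the caller's
    # progress bar once per yielded path, if a callback was given.
    for i in range(1000):
        if progress_callback:
            progress_callback()
        yield "{}/{:060x}".format(prefix or "00", i)


def estimate_remote_size(total_prefixes=4096):
    # Mirrors the cache_exists() change: list a single prefix and
    # scale each tick by the prefix count to estimate total size.
    with tqdm(desc="Estimating size of remote cache", unit="file") as pbar:

        def update(n=1):
            pbar.update(n * total_prefixes)

        paths = list(
            fake_list_cache_paths(prefix="000", progress_callback=update)
        )
    return total_prefixes * len(paths)


if __name__ == "__main__":
    print(estimate_remote_size())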