From 48a8078f558b9ff6b2b100d11fd5b6bbc469dade Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 25 Mar 2020 15:07:37 +0900 Subject: [PATCH 1/7] remote: use progress bar when paginating - add `progress_callback` parameter to `list_cache_paths()` so that remotes can update remote cache traverse progress bar after fetching a page --- dvc/remote/azure.py | 12 +++++++++--- dvc/remote/base.py | 25 +++++++++++++------------ dvc/remote/gdrive.py | 14 ++++++++++---- dvc/remote/gs.py | 23 ++++++++++++++++------- dvc/remote/hdfs.py | 7 +++++-- dvc/remote/local.py | 2 +- dvc/remote/oss.py | 17 +++++++++++++---- dvc/remote/s3.py | 23 +++++++++++++++++------ dvc/remote/ssh/__init__.py | 2 +- tests/unit/remote/test_base.py | 12 ++++++++---- 10 files changed, 93 insertions(+), 44 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 23770ad18e..36281c2622 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -29,6 +29,7 @@ class RemoteAZURE(RemoteBASE): REQUIRES = {"azure-storage-blob": "azure.storage.blob"} PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 5000 def __init__(self, repo, config): super().__init__(repo, config) @@ -88,7 +89,7 @@ def remove(self, path_info): logger.debug("Removing {}".format(path_info)) self.blob_service.delete_blob(path_info.bucket, path_info.path) - def _list_paths(self, bucket, prefix): + def _list_paths(self, bucket, prefix, progress_callback=None): blob_service = self.blob_service next_marker = None while True: @@ -96,6 +97,9 @@ def _list_paths(self, bucket, prefix): bucket, prefix=prefix, marker=next_marker ) + if progress_callback: + progress_callback(len(blobs)) + for blob in blobs: yield blob.name @@ -104,14 +108,16 @@ def _list_paths(self, bucket, prefix): next_marker = blobs.next_marker - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: prefix = posixpath.join( self.path_info.path, prefix[:2], prefix[2:] ) else: prefix = self.path_info.path - return self._list_paths(self.path_info.bucket, prefix) + return self._list_paths( + self.path_info.bucket, prefix, progress_callback + ) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 8b36a51004..1d4fba22d8 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -691,7 +691,7 @@ def path_to_checksum(self, path): def checksum_to_path_info(self, checksum): return self.path_info / checksum[0:2] / checksum[2:] - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): raise NotImplementedError def all(self): @@ -892,11 +892,11 @@ def cache_exists(self, checksums, jobs=None, name=None): return list(checksums & set(self.all())) return self._cache_exists_traverse( - checksums, remote_checksums, jobs, name + checksums, remote_checksums, remote_size, jobs, name ) def _cache_exists_traverse( - self, checksums, remote_checksums, jobs=None, name=None + self, checksums, remote_checksums, remote_size, jobs=None, name=None ): logger.debug( "Querying {} checksums via threaded traverse".format( @@ -912,20 +912,21 @@ def _cache_exists_traverse( ] with Tqdm( desc="Querying " - + ("cache in " + name if name else "remote cache"), - total=len(traverse_prefixes), - unit="dir", + + ("cache in '{}'".format(name) if name else "remote cache"), + total=remote_size, + initial=len(remote_checksums), + unit="objects", ) as pbar: def list_with_update(prefix): - ret = map( + return map( 
self.path_to_checksum, - list(self.list_cache_paths(prefix=prefix)), - ) - pbar.update_desc( - "Querying cache in '{}'".format(self.path_info / prefix) + list( + self.list_cache_paths( + prefix=prefix, progress_callback=pbar.update + ) + ), ) - return ret with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: in_remote = executor.map(list_with_update, traverse_prefixes,) diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index a695dd09b3..4b73734457 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -304,14 +304,20 @@ def gdrive_download_file( ): gdrive_file.GetContentFile(to_file) - def gdrive_list_item(self, query): + def gdrive_list_item(self, query, progress_callback=None): param = {"q": query, "maxResults": 1000} param.update(self.list_params) file_list = self.drive.ListFile(param) # Isolate and decorate fetching of remote drive items in pages - get_list = gdrive_retry(lambda: next(file_list, None)) + def next_list(): + list_ = next(file_list, None) + if list_ and progress_callback: + progress_callback(len(list_)) + return list_ + + get_list = gdrive_retry(next_list) # Fetch pages until None is received, lazily flatten the thing return cat(iter(get_list, None)) @@ -455,7 +461,7 @@ def _download(self, from_info, to_file, name, no_progress_bar): file_id = self._get_remote_id(from_info) self.gdrive_download_file(file_id, to_file, name, no_progress_bar) - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): if not self.cache["ids"]: return @@ -470,7 +476,7 @@ def list_cache_paths(self, prefix=None): ) query = "({}) and trashed=false".format(parents_query) - for item in self.gdrive_list_item(query): + for item in self.gdrive_list_item(query, progress_callback): parent_id = item["parents"][0]["id"] yield posixpath.join(self.cache["ids"][parent_id], item["title"]) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index eab424458f..b067ef901f 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -127,18 +127,27 @@ def remove(self, path_info): blob.delete() - def _list_paths(self, path_info, max_items=None, prefix=None): + def _list_paths( + self, path_info, max_items=None, prefix=None, progress_callback=None + ): if prefix: prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:]) else: prefix = path_info.path - for blob in self.gs.bucket(path_info.bucket).list_blobs( - prefix=path_info.path, max_results=max_items + for page in ( + self.gs.bucket(path_info.bucket) + .list_blobs(prefix=path_info.path, max_results=max_items) + .pages ): - yield blob.name - - def list_cache_paths(self, prefix=None): - return self._list_paths(self.path_info, prefix=prefix) + if progress_callback: + progress_callback.update(page.num_items) + for blob in page: + yield blob.name + + def list_cache_paths(self, prefix=None, progress_callback=None): + return self._list_paths( + self.path_info, prefix=prefix, progress_callback=progress_callback + ) def walk_files(self, path_info): for fname in self._list_paths(path_info / ""): diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index f7e3439740..16935e4f10 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -153,7 +153,7 @@ def open(self, path_info, mode="r", encoding=None): raise FileNotFoundError(*e.args) raise - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): if not self.exists(self.path_info): return @@ -166,7 +166,10 @@ def list_cache_paths(self, prefix=None): with self.hdfs(self.path_info) as hdfs: while 
dirs: try: - for entry in hdfs.ls(dirs.pop(), detail=True): + entries = hdfs.ls(dirs.pop(), detail=True) + if progress_callback: + progress_callback.update(len(entries)) + for entry in entries: if entry["kind"] == "directory": dirs.append(urlparse(entry["name"]).path) elif entry["kind"] == "file": diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 7a93434f3d..a5a17308a0 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -56,7 +56,7 @@ def cache_dir(self, value): def supported(cls, config): return True - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): assert prefix is None assert self.path_info is not None return walk_files(self.path_info) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 9d928f7cb0..9fee4eb63f 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -38,6 +38,7 @@ class RemoteOSS(RemoteBASE): REQUIRES = {"oss2": "oss2"} PARAM_CHECKSUM = "etag" COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 100 def __init__(self, repo, config): super().__init__(repo, config) @@ -90,20 +91,28 @@ def remove(self, path_info): logger.debug("Removing oss://{}".format(path_info)) self.oss_service.delete_object(path_info.path) - def _list_paths(self, prefix): + def _list_paths(self, prefix, progress_callback=None): import oss2 - for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): + # oss2.ObjectIterator lacks any convenient way to iterate over pages + # instead of the flattened list of blobs + count = 0 + iterator = oss2.ObjectIterator(self.oss_service, prefix=prefix) + for blob in iterator: yield blob.key + if progress_callback: + count = (count + 1) % iterator.max_keys + if count == 0: + progress_callback(iterator.max_keys) - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: prefix = posixpath.join( self.path_info.path, prefix[:2], prefix[2:] ) else: prefix = self.path_info.path - return self._list_paths(prefix) + return self._list_paths(prefix, progress_callback) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 7bc9b0fcc1..8c961267b0 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -191,7 +191,9 @@ def remove(self, path_info): logger.debug("Removing {}".format(path_info)) self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) - def _list_objects(self, path_info, max_items=None, prefix=None): + def _list_objects( + self, path_info, max_items=None, prefix=None, progress_callback=None + ): """ Read config for list object api, paginate through list objects.""" kwargs = { "Bucket": path_info.bucket, @@ -202,16 +204,25 @@ def _list_objects(self, path_info, max_items=None, prefix=None): kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2]) paginator = self.s3.get_paginator(self.list_objects_api) for page in paginator.paginate(**kwargs): - yield from page.get("Contents", ()) + contents = page.get("Contents", ()) + if progress_callback: + progress_callback(len(contents)) + yield from contents - def _list_paths(self, path_info, max_items=None, prefix=None): + def _list_paths( + self, path_info, max_items=None, prefix=None, progress_callback=None + ): return ( item["Key"] - for item in self._list_objects(path_info, max_items, prefix) + for item in self._list_objects( + path_info, max_items, prefix, progress_callback + ) ) - def list_cache_paths(self, prefix=None): - return self._list_paths(self.path_info, prefix=prefix) + 
def list_cache_paths(self, prefix=None, progress_callback=None): + return self._list_paths( + self.path_info, prefix=prefix, progress_callback=progress_callback + ) def isfile(self, path_info): from botocore.exceptions import ClientError diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 1f191fff65..61f25ac900 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -249,7 +249,7 @@ def open(self, path_info, mode="r", encoding=None): else: yield io.TextIOWrapper(fd, encoding=encoding) - def list_cache_paths(self, prefix=None): + def list_cache_paths(self, prefix=None, progress_callback=None): assert prefix is None with self.ssh(self.path_info) as ssh: # If we simply return an iterator then with above closes instantly diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 54ec121a9a..7ac81755c0 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -82,7 +82,7 @@ def test_cache_exists(path_to_checksum, object_exists, traverse): remote.cache_exists(checksums) object_exists.assert_not_called() traverse.assert_called_with( - frozenset(checksums), set(range(256)), None, None + frozenset(checksums), set(range(256)), mock.ANY, None, None ) # default traverse @@ -105,8 +105,12 @@ def test_cache_exists(path_to_checksum, object_exists, traverse): def test_cache_exists_traverse(path_to_checksum, list_cache_paths): remote = RemoteBASE(None, {}) remote.path_info = PathInfo("foo") - remote._cache_exists_traverse({0}, set()) + remote._cache_exists_traverse({0}, set(), 4096) for i in range(1, 16): - list_cache_paths.assert_any_call(prefix="{:03x}".format(i)) + list_cache_paths.assert_any_call( + prefix="{:03x}".format(i), progress_callback=mock.ANY + ) for i in range(1, 256): - list_cache_paths.assert_any_call(prefix="{:02x}".format(i)) + list_cache_paths.assert_any_call( + prefix="{:02x}".format(i), progress_callback=mock.ANY + ) From 1c73fe75c786a48b870f9dbcc261857639461dc0 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Wed, 25 Mar 2020 15:59:54 +0900 Subject: [PATCH 2/7] Show remote size estimation status as no-total progress bar --- dvc/remote/base.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 1d4fba22d8..31ace20da3 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -847,9 +847,23 @@ def cache_exists(self, checksums, jobs=None, name=None): checksums = frozenset(checksums) prefix = "0" * self.TRAVERSE_PREFIX_LEN total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) - remote_checksums = set( - map(self.path_to_checksum, self.list_cache_paths(prefix=prefix)) - ) + with Tqdm( + desc="Estimating size of " + + ("cache in '{}'".format(name) if name else "remote cache"), + bar_format=Tqdm.BAR_FMT_NOTOTAL, + unit="objects", + ) as pbar: + remote_checksums = set( + map( + self.path_to_checksum, + self.list_cache_paths( + prefix=prefix, + progress_callback=lambda n: pbar.update( + n * total_prefixes + ), + ), + ) + ) if remote_checksums: remote_size = total_prefixes * len(remote_checksums) else: From d9ca41baf3d461d6bbd1cc94bdc43899ff1abec1 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 26 Mar 2020 12:50:38 +0900 Subject: [PATCH 3/7] Update pbar per item rather than per page --- dvc/remote/azure.py | 5 ++--- dvc/remote/base.py | 5 ++--- dvc/remote/gdrive.py | 14 +++++--------- dvc/remote/gs.py | 11 ++++------- dvc/remote/hdfs.py | 4 ++-- dvc/remote/oss.py | 12 +++--------- dvc/remote/s3.py | 7 +++++-- 7 
files changed, 23 insertions(+), 35 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 36281c2622..f126aaeb43 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -97,10 +97,9 @@ def _list_paths(self, bucket, prefix, progress_callback=None): bucket, prefix=prefix, marker=next_marker ) - if progress_callback: - progress_callback(len(blobs)) - for blob in blobs: + if progress_callback: + progress_callback() yield blob.name if not blobs.next_marker: diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 31ace20da3..541779b5af 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -850,15 +850,14 @@ def cache_exists(self, checksums, jobs=None, name=None): with Tqdm( desc="Estimating size of " + ("cache in '{}'".format(name) if name else "remote cache"), - bar_format=Tqdm.BAR_FMT_NOTOTAL, - unit="objects", + unit="file", ) as pbar: remote_checksums = set( map( self.path_to_checksum, self.list_cache_paths( prefix=prefix, - progress_callback=lambda n: pbar.update( + progress_callback=lambda n=1: pbar.update( n * total_prefixes ), ), diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 4b73734457..27f4462023 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -304,20 +304,14 @@ def gdrive_download_file( ): gdrive_file.GetContentFile(to_file) - def gdrive_list_item(self, query, progress_callback=None): + def gdrive_list_item(self, query): param = {"q": query, "maxResults": 1000} param.update(self.list_params) file_list = self.drive.ListFile(param) # Isolate and decorate fetching of remote drive items in pages - def next_list(): - list_ = next(file_list, None) - if list_ and progress_callback: - progress_callback(len(list_)) - return list_ - - get_list = gdrive_retry(next_list) + get_list = gdrive_retry(lambda: next(file_list, None)) # Fetch pages until None is received, lazily flatten the thing return cat(iter(get_list, None)) @@ -476,7 +470,9 @@ def list_cache_paths(self, prefix=None, progress_callback=None): ) query = "({}) and trashed=false".format(parents_query) - for item in self.gdrive_list_item(query, progress_callback): + for item in self.gdrive_list_item(query): + if progress_callback: + progress_callback() parent_id = item["parents"][0]["id"] yield posixpath.join(self.cache["ids"][parent_id], item["title"]) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index b067ef901f..6f53fc8d13 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -134,15 +134,12 @@ def _list_paths( prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:]) else: prefix = path_info.path - for page in ( - self.gs.bucket(path_info.bucket) - .list_blobs(prefix=path_info.path, max_results=max_items) - .pages + for blob in self.gs.bucket(path_info.bucket).list_blobs( + prefix=path_info.path, max_results=max_items ): if progress_callback: - progress_callback.update(page.num_items) - for blob in page: - yield blob.name + progress_callback() + yield blob.name def list_cache_paths(self, prefix=None, progress_callback=None): return self._list_paths( diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 16935e4f10..b280aece20 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -167,12 +167,12 @@ def list_cache_paths(self, prefix=None, progress_callback=None): while dirs: try: entries = hdfs.ls(dirs.pop(), detail=True) - if progress_callback: - progress_callback.update(len(entries)) for entry in entries: if entry["kind"] == "directory": dirs.append(urlparse(entry["name"]).path) elif entry["kind"] == "file": + if progress_callback: + 
progress_callback.update() yield urlparse(entry["name"]).path except IOError as e: # When searching for a specific prefix pyarrow raises an diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 9fee4eb63f..9642fb511b 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -94,16 +94,10 @@ def remove(self, path_info): def _list_paths(self, prefix, progress_callback=None): import oss2 - # oss2.ObjectIterator lacks any convenient way to iterate over pages - # instead of the flattened list of blobs - count = 0 - iterator = oss2.ObjectIterator(self.oss_service, prefix=prefix) - for blob in iterator: - yield blob.key + for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): if progress_callback: - count = (count + 1) % iterator.max_keys - if count == 0: - progress_callback(iterator.max_keys) + progress_callback() + yield blob.key def list_cache_paths(self, prefix=None, progress_callback=None): if prefix: diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 8c961267b0..c43e46bb3e 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -206,8 +206,11 @@ def _list_objects( for page in paginator.paginate(**kwargs): contents = page.get("Contents", ()) if progress_callback: - progress_callback(len(contents)) - yield from contents + for item in contents: + progress_callback() + yield item + else: + yield from contents def _list_paths( self, path_info, max_items=None, prefix=None, progress_callback=None From d84fd8cc0ca05cec7595d2ba95d9e9bb58610fce Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Thu, 26 Mar 2020 15:17:51 +0900 Subject: [PATCH 4/7] make tests more explicit (don't use mock.ANY) --- tests/unit/remote/test_base.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 7ac81755c0..df688791c1 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -8,6 +8,16 @@ from dvc.remote.base import RemoteMissingDepsError +class _CallableOrNone(object): + """Helper for testing if object is callable() or None.""" + + def __eq__(self, other): + return True if other is None else callable(other) + + +CallableOrNone = _CallableOrNone() + + class TestRemoteBASE(object): REMOTE_CLS = RemoteBASE @@ -82,7 +92,11 @@ def test_cache_exists(path_to_checksum, object_exists, traverse): remote.cache_exists(checksums) object_exists.assert_not_called() traverse.assert_called_with( - frozenset(checksums), set(range(256)), mock.ANY, None, None + frozenset(checksums), + set(range(256)), + 256 * pow(16, remote.TRAVERSE_PREFIX_LEN), + None, + None, ) # default traverse @@ -108,9 +122,9 @@ def test_cache_exists_traverse(path_to_checksum, list_cache_paths): remote._cache_exists_traverse({0}, set(), 4096) for i in range(1, 16): list_cache_paths.assert_any_call( - prefix="{:03x}".format(i), progress_callback=mock.ANY + prefix="{:03x}".format(i), progress_callback=CallableOrNone ) for i in range(1, 256): list_cache_paths.assert_any_call( - prefix="{:02x}".format(i), progress_callback=mock.ANY + prefix="{:02x}".format(i), progress_callback=CallableOrNone ) From 56c29423cf5f4e142613672ab6a95d414461d0f1 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Thu, 26 Mar 2020 14:57:01 +0000 Subject: [PATCH 5/7] Apply suggestions from code review minor fixes/tidy --- dvc/remote/hdfs.py | 2 +- tests/unit/remote/test_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index b280aece20..c874b20682 100644 --- 
a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -172,7 +172,7 @@ def list_cache_paths(self, prefix=None, progress_callback=None): dirs.append(urlparse(entry["name"]).path) elif entry["kind"] == "file": if progress_callback: - progress_callback.update() + progress_callback() yield urlparse(entry["name"]).path except IOError as e: # When searching for a specific prefix pyarrow raises an diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index df688791c1..64e64bc676 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -12,7 +12,7 @@ class _CallableOrNone(object): """Helper for testing if object is callable() or None.""" def __eq__(self, other): - return True if other is None else callable(other) + return other is None or callable(other) CallableOrNone = _CallableOrNone() From b200d7d5d53f9d2b2a887ea376717912b4c370cc Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Fri, 27 Mar 2020 12:08:11 +0900 Subject: [PATCH 6/7] Use progress_callback in ssh/local list_cache_paths() - keep ssh/local consistent with other remotes (even though `dvc gc`/`RemoteBASE.all()` do not currently use progress bars) --- dvc/remote/local.py | 7 ++++++- dvc/remote/ssh/__init__.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index a5a17308a0..708480a135 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -59,7 +59,12 @@ def supported(cls, config): def list_cache_paths(self, prefix=None, progress_callback=None): assert prefix is None assert self.path_info is not None - return walk_files(self.path_info) + if progress_callback: + for path in walk_files(self.path_info): + progress_callback() + yield path + else: + yield from walk_files(self.path_info) def get(self, md5): if not md5: diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 61f25ac900..703617b2af 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -253,7 +253,12 @@ def list_cache_paths(self, prefix=None, progress_callback=None): assert prefix is None with self.ssh(self.path_info) as ssh: # If we simply return an iterator then with above closes instantly - yield from ssh.walk_files(self.path_info.path) + if progress_callback: + for path in ssh.walk_files(self.path_info.path): + progress_callback() + yield path + else: + yield from ssh.walk_files(self.path_info.path) def walk_files(self, path_info): with self.ssh(path_info) as ssh: From ba1dc78b7bbc55e494347c0bc376fd6f1370f45e Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Sat, 28 Mar 2020 12:10:50 +0900 Subject: [PATCH 7/7] Fix review issues --- dvc/remote/base.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 541779b5af..195dc62e86 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -852,17 +852,15 @@ def cache_exists(self, checksums, jobs=None, name=None): + ("cache in '{}'".format(name) if name else "remote cache"), unit="file", ) as pbar: - remote_checksums = set( - map( - self.path_to_checksum, - self.list_cache_paths( - prefix=prefix, - progress_callback=lambda n=1: pbar.update( - n * total_prefixes - ), - ), - ) + + def update(n=1): + pbar.update(n * total_prefixes) + + paths = self.list_cache_paths( + prefix=prefix, progress_callback=update ) + remote_checksums = set(map(self.path_to_checksum, paths)) + if remote_checksums: remote_size = total_prefixes * len(remote_checksums) else: @@ -932,14 +930,10 @@ def 
_cache_exists_traverse( ) as pbar: def list_with_update(prefix): - return map( - self.path_to_checksum, - list( - self.list_cache_paths( - prefix=prefix, progress_callback=pbar.update - ) - ), + paths = self.list_cache_paths( + prefix=prefix, progress_callback=pbar.update ) + return map(self.path_to_checksum, list(paths)) with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: in_remote = executor.map(list_with_update, traverse_prefixes,)
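
Taken together, the series converges on a two-phase pattern: cache_exists() first lists a single checksum prefix to estimate the remote's size, then _cache_exists_traverse() walks the remaining prefixes in a thread pool, with both phases driving a progress bar through the new progress_callback parameter. Below is a minimal, self-contained sketch of the estimation phase only, assuming tqdm and any list_cache_paths-style callable; estimate_remote_size and fake_list_cache_paths are illustrative names for this sketch, not DVC APIs.

from tqdm import tqdm

TRAVERSE_PREFIX_LEN = 3  # sample everything under the "000" prefix

def estimate_remote_size(list_cache_paths, name=None):
    """Extrapolate the remote object count from a single-prefix listing."""
    total_prefixes = pow(16, TRAVERSE_PREFIX_LEN)
    prefix = "0" * TRAVERSE_PREFIX_LEN
    with tqdm(
        desc="Estimating size of "
        + ("cache in '{}'".format(name) if name else "remote cache"),
        unit="file",
    ) as pbar:

        def update(n=1):
            # Each object found under one prefix statistically stands in
            # for `total_prefixes` objects across the whole remote, so the
            # bar advances by that factor per callback.
            pbar.update(n * total_prefixes)

        paths = list(
            list_cache_paths(prefix=prefix, progress_callback=update)
        )
    # Guard against an empty sample, as the final patch does, so the
    # estimate is never zero.
    return total_prefixes * max(len(paths), 1), paths

def fake_list_cache_paths(prefix=None, progress_callback=None):
    # Stand-in for a real remote: pretend five objects live under `prefix`.
    for i in range(5):
        if progress_callback:
            progress_callback()
        yield "{}/{:037x}".format(prefix, i)

size, sampled = estimate_remote_size(fake_list_cache_paths)
print(size)  # 5 * 16**3 == 20480

The per-item callback style from patch 3 keeps the bar smooth for remotes that yield object by object, while the base-class updater keeps an optional count argument (lambda n=1: ...) so paginated remotes could still report a whole page in a single call, as the patch-1 versions of the S3 and Azure backends did.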