Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remote: use progress bar when paginating #3532

Merged
merged 7 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions dvc/remote/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class RemoteAZURE(RemoteBASE):
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5
LIST_OBJECT_PAGE_SIZE = 5000
casperdcl marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, repo, config):
super().__init__(repo, config)
Expand Down Expand Up @@ -88,7 +89,7 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.blob_service.delete_blob(path_info.bucket, path_info.path)

def _list_paths(self, bucket, prefix):
def _list_paths(self, bucket, prefix, progress_callback=None):
blob_service = self.blob_service
next_marker = None
while True:
Expand All @@ -97,21 +98,25 @@ def _list_paths(self, bucket, prefix):
)

for blob in blobs:
if progress_callback:
progress_callback()
yield blob.name

if not blobs.next_marker:
break

next_marker = blobs.next_marker

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(self.path_info.bucket, prefix)
return self._list_paths(
self.path_info.bucket, prefix, progress_callback
)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand Down
40 changes: 24 additions & 16 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def path_to_checksum(self, path):
def checksum_to_path_info(self, checksum):
return self.path_info / checksum[0:2] / checksum[2:]

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

progress_callback is unused in the ssh and local remotes, but the parameter was added to both so that the overridden list_cache_paths() in ssh/local still matches the abstract method defined in RemoteBASE.list_cache_paths()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but presumably it should be used now that it's added?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both local and ssh have their own overridden cache_exists() implementation and neither one uses list_cache_paths() at all during cache_exists(). The only time list_cache_paths() would get called for either local or ssh remotes would be for RemoteBASE.all() during a dvc gc call.

all() does not currently display any progressbars at all, but now that I'm thinking about this, it probably should

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed; though I'm not sure in principle about having a function that takes a non-null progress_callback and ignores it. This implies a bar will be displayed hanging at 0% on screen and then eventually suddenly disappear.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list_cache_paths() for local and ssh have been updated to use progress_callback (if it is set) to keep them consistent with the other remotes. A separate issue (#3543) has been filed for updating the dvc gc/RemoteBASE.all() behavior and will be addressed in a follow-up PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw just to clarify - every time I've said "unused?" it's more a reminder to address at some point (could be in a future PR) rather than an expectation to address in this PR.

raise NotImplementedError

def all(self):
Expand Down Expand Up @@ -847,9 +847,20 @@ def cache_exists(self, checksums, jobs=None, name=None):
checksums = frozenset(checksums)
prefix = "0" * self.TRAVERSE_PREFIX_LEN
total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN)
remote_checksums = set(
map(self.path_to_checksum, self.list_cache_paths(prefix=prefix))
)
with Tqdm(
desc="Estimating size of "
+ ("cache in '{}'".format(name) if name else "remote cache"),
unit="file",
) as pbar:

def update(n=1):
pbar.update(n * total_prefixes)

paths = self.list_cache_paths(
prefix=prefix, progress_callback=update
)
remote_checksums = set(map(self.path_to_checksum, paths))

if remote_checksums:
remote_size = total_prefixes * len(remote_checksums)
else:
Expand Down Expand Up @@ -892,11 +903,11 @@ def cache_exists(self, checksums, jobs=None, name=None):
return list(checksums & set(self.all()))

return self._cache_exists_traverse(
checksums, remote_checksums, jobs, name
checksums, remote_checksums, remote_size, jobs, name
)

def _cache_exists_traverse(
self, checksums, remote_checksums, jobs=None, name=None
self, checksums, remote_checksums, remote_size, jobs=None, name=None
):
logger.debug(
"Querying {} checksums via threaded traverse".format(
Expand All @@ -912,20 +923,17 @@ def _cache_exists_traverse(
]
with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
total=len(traverse_prefixes),
unit="dir",
+ ("cache in '{}'".format(name) if name else "remote cache"),
total=remote_size,
initial=len(remote_checksums),
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
unit="objects",
) as pbar:

def list_with_update(prefix):
ret = map(
self.path_to_checksum,
list(self.list_cache_paths(prefix=prefix)),
paths = self.list_cache_paths(
prefix=prefix, progress_callback=pbar.update
)
pbar.update_desc(
"Querying cache in '{}'".format(self.path_info / prefix)
)
return ret
return map(self.path_to_checksum, list(paths))

with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
in_remote = executor.map(list_with_update, traverse_prefixes,)
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if not self.cache["ids"]:
return

Expand All @@ -471,6 +471,8 @@ def list_cache_paths(self, prefix=None):
query = "({}) and trashed=false".format(parents_query)

for item in self.gdrive_list_item(query):
if progress_callback:
progress_callback()
parent_id = item["parents"][0]["id"]
yield posixpath.join(self.cache["ids"][parent_id], item["title"])

Expand Down
12 changes: 9 additions & 3 deletions dvc/remote/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,24 @@ def remove(self, path_info):

blob.delete()

def _list_paths(self, path_info, max_items=None, prefix=None):
def _list_paths(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
if prefix:
prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:])
else:
prefix = path_info.path
for blob in self.gs.bucket(path_info.bucket).list_blobs(
prefix=path_info.path, max_results=max_items
):
if progress_callback:
progress_callback()
yield blob.name

def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)
def list_cache_paths(self, prefix=None, progress_callback=None):
return self._list_paths(
self.path_info, prefix=prefix, progress_callback=progress_callback
)

def walk_files(self, path_info):
for fname in self._list_paths(path_info / ""):
Expand Down
7 changes: 5 additions & 2 deletions dvc/remote/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def open(self, path_info, mode="r", encoding=None):
raise FileNotFoundError(*e.args)
raise

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if not self.exists(self.path_info):
return

Expand All @@ -166,10 +166,13 @@ def list_cache_paths(self, prefix=None):
with self.hdfs(self.path_info) as hdfs:
while dirs:
try:
for entry in hdfs.ls(dirs.pop(), detail=True):
entries = hdfs.ls(dirs.pop(), detail=True)
for entry in entries:
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
if progress_callback:
progress_callback()
yield urlparse(entry["name"]).path
except IOError as e:
# When searching for a specific prefix pyarrow raises an
Expand Down
9 changes: 7 additions & 2 deletions dvc/remote/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ def cache_dir(self, value):
def supported(cls, config):
return True

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
assert prefix is None
assert self.path_info is not None
return walk_files(self.path_info)
if progress_callback:
for path in walk_files(self.path_info):
progress_callback()
yield path
else:
yield from walk_files(self.path_info)

def get(self, md5):
if not md5:
Expand Down
9 changes: 6 additions & 3 deletions dvc/remote/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RemoteOSS(RemoteBASE):
REQUIRES = {"oss2": "oss2"}
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5
LIST_OBJECT_PAGE_SIZE = 100
casperdcl marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, repo, config):
super().__init__(repo, config)
Expand Down Expand Up @@ -90,20 +91,22 @@ def remove(self, path_info):
logger.debug("Removing oss://{}".format(path_info))
self.oss_service.delete_object(path_info.path)

def _list_paths(self, prefix):
def _list_paths(self, prefix, progress_callback=None):
import oss2

for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
if progress_callback:
progress_callback()
yield blob.key

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(prefix)
return self._list_paths(prefix, progress_callback)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand Down
28 changes: 21 additions & 7 deletions dvc/remote/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)

def _list_objects(self, path_info, max_items=None, prefix=None):
def _list_objects(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
""" Read config for list object api, paginate through list objects."""
kwargs = {
"Bucket": path_info.bucket,
Expand All @@ -202,16 +204,28 @@ def _list_objects(self, path_info, max_items=None, prefix=None):
kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2])
paginator = self.s3.get_paginator(self.list_objects_api)
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", ())

def _list_paths(self, path_info, max_items=None, prefix=None):
contents = page.get("Contents", ())
if progress_callback:
for item in contents:
progress_callback()
yield item
else:
yield from contents

def _list_paths(
self, path_info, max_items=None, prefix=None, progress_callback=None
):
return (
item["Key"]
for item in self._list_objects(path_info, max_items, prefix)
for item in self._list_objects(
path_info, max_items, prefix, progress_callback
)
)

def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)
def list_cache_paths(self, prefix=None, progress_callback=None):
return self._list_paths(
self.path_info, prefix=prefix, progress_callback=progress_callback
)

def isfile(self, path_info):
from botocore.exceptions import ClientError
Expand Down
9 changes: 7 additions & 2 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,16 @@ def open(self, path_info, mode="r", encoding=None):
else:
yield io.TextIOWrapper(fd, encoding=encoding)

def list_cache_paths(self, prefix=None):
def list_cache_paths(self, prefix=None, progress_callback=None):
pmrowla marked this conversation as resolved.
Show resolved Hide resolved
assert prefix is None
with self.ssh(self.path_info) as ssh:
# If we simply return an iterator then with above closes instantly
yield from ssh.walk_files(self.path_info.path)
if progress_callback:
for path in ssh.walk_files(self.path_info.path):
progress_callback()
yield path
else:
yield from ssh.walk_files(self.path_info.path)

def walk_files(self, path_info):
with self.ssh(path_info) as ssh:
Expand Down
26 changes: 22 additions & 4 deletions tests/unit/remote/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
from dvc.remote.base import RemoteMissingDepsError


class _CallableOrNone(object):
"""Helper for testing if object is callable() or None."""

def __eq__(self, other):
return other is None or callable(other)


CallableOrNone = _CallableOrNone()


class TestRemoteBASE(object):
REMOTE_CLS = RemoteBASE

Expand Down Expand Up @@ -82,7 +92,11 @@ def test_cache_exists(path_to_checksum, object_exists, traverse):
remote.cache_exists(checksums)
object_exists.assert_not_called()
traverse.assert_called_with(
frozenset(checksums), set(range(256)), None, None
frozenset(checksums),
set(range(256)),
256 * pow(16, remote.TRAVERSE_PREFIX_LEN),
None,
None,
)

# default traverse
Expand All @@ -105,8 +119,12 @@ def test_cache_exists(path_to_checksum, object_exists, traverse):
def test_cache_exists_traverse(path_to_checksum, list_cache_paths):
remote = RemoteBASE(None, {})
remote.path_info = PathInfo("foo")
remote._cache_exists_traverse({0}, set())
remote._cache_exists_traverse({0}, set(), 4096)
for i in range(1, 16):
list_cache_paths.assert_any_call(prefix="{:03x}".format(i))
list_cache_paths.assert_any_call(
prefix="{:03x}".format(i), progress_callback=CallableOrNone
)
for i in range(1, 256):
list_cache_paths.assert_any_call(prefix="{:02x}".format(i))
list_cache_paths.assert_any_call(
prefix="{:02x}".format(i), progress_callback=CallableOrNone
)