remote: Optimize traverse/no_traverse behavior
- estimate the remote file count by fetching a single parent cache dir, then
  determine whether or not to use the no_traverse method for checking the
  remainder of the cache entries in `cache_exists` (sketched below)
- thread requests when traversing the full remote file list (by fetching
  one parent cache dir per thread)
pmrowla committed Mar 20, 2020
1 parent 8f546f2 commit 97aedaa
Showing 9 changed files with 150 additions and 27 deletions.
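The decision logic this commit adds to `cache_exists` boils down to a small amount of arithmetic: sample the "00/" parent cache dir, scale the count by 256 to estimate the total remote size, and compare the number of queried checksums against a weighted page count. The standalone Python sketch below restates that heuristic; the function names and the assumed JOBS value are hypothetical, while the other constants mirror the ones introduced in dvc/remote/base.py below.

LIST_OBJECT_PAGE_SIZE = 1000      # page size used by cloud list-objects calls
TRAVERSE_WEIGHT_MULTIPLIER = 20   # penalty applied to traverse on very large remotes
LARGE_REMOTE_THRESHOLD = 500000   # above this, the traverse cost is weighted
JOBS = 16                         # assumed default number of worker threads


def estimate_remote_size(entries_under_00):
    """Scale the entry count sampled from the '00/' dir to all 256 prefixes."""
    return 256 * max(1, len(entries_under_00))


def choose_strategy(num_checksums, remote_size):
    """Pick the cache_exists strategy this heuristic would use."""
    traverse_pages = remote_size / LIST_OBJECT_PAGE_SIZE
    if remote_size > LARGE_REMOTE_THRESHOLD:
        traverse_weight = traverse_pages * TRAVERSE_WEIGHT_MULTIPLIER
    else:
        traverse_weight = traverse_pages
    if num_checksums < traverse_weight:
        return "no_traverse"          # query each checksum individually
    if traverse_pages < 256 / JOBS:
        return "sequential_traverse"  # one full listing beats 255 extra requests
    return "threaded_traverse"        # list one parent cache dir per thread


# Example: ~4000 entries under "00/" implies roughly 1M files in the remote,
# so a query of 10,000 checksums still falls back to no_traverse.
print(choose_strategy(10000, estimate_remote_size(range(4000))))
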
dvc/remote/azure.py: 7 additions & 2 deletions
@@ -26,6 +26,7 @@ class RemoteAZURE(RemoteBASE):
r"(?P<connection_string>.+)?)?)$"
)
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5

@@ -103,8 +104,12 @@ def _list_paths(self, bucket, prefix):

next_marker = blobs.next_marker

def list_cache_paths(self):
return self._list_paths(self.path_info.bucket, self.path_info.path)
def list_cache_paths(self, prefix=None):
if prefix:
prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]])
else:
prefix = self.path_info.path
return self._list_paths(self.path_info.bucket, prefix)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
dvc/remote/base.py: 89 additions & 4 deletions
@@ -82,6 +82,8 @@ class RemoteBASE(object):
DEFAULT_CACHE_TYPES = ["copy"]
DEFAULT_NO_TRAVERSE = True
DEFAULT_VERIFY = False
LIST_OBJECT_PAGE_SIZE = 1000
TRAVERSE_WEIGHT_MULTIPLIER = 20

CACHE_MODE = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
@@ -686,7 +688,7 @@ def path_to_checksum(self, path):
def checksum_to_path_info(self, checksum):
return self.path_info / checksum[0:2] / checksum[2:]

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
raise NotImplementedError

def all(self):
@@ -804,7 +806,9 @@ def cache_exists(self, checksums, jobs=None, name=None):
- Traverse: Get a list of all the files in the remote
(traversing the cache directory) and compare it with
the given checksums.
the given checksums. Cache parent directories will
be retrieved in parallel threads, and a progress bar
will be displayed.
- No traverse: For each given checksum, run the `exists`
method and filter the checksums that aren't on the remote.
@@ -820,9 +824,90 @@
Returns:
A list with checksums that were found in the remote
"""
if not self.no_traverse:
return list(set(checksums) & set(self.all()))

if self.no_traverse:
return self._cache_exists_no_traverse(checksums, jobs, name)

# Fetch one parent cache dir for estimating size of entire remote cache
checksums = frozenset(checksums)
remote_checksums = set(
map(self.path_to_checksum, self.list_cache_paths(prefix="00"))
)
if not remote_checksums:
remote_size = 256
else:
remote_size = 256 * len(remote_checksums)
logger.debug("Estimated remote size: {} files".format(remote_size))

traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE
# For sufficiently large remotes, traverse must be weighted to account
# for performance overhead from large lists/sets.
# From testing with S3, for remotes with 1M+ files, no_traverse is
# faster until len(checksums) is at least 10k~100k
if remote_size > 500000:
traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER
else:
traverse_weight = traverse_pages
if len(checksums) < traverse_weight:
logger.debug(
"Large remote, using no_traverse for remaining checksums"
)
return list(
checksums & remote_checksums
) + self._cache_exists_no_traverse(
checksums - remote_checksums, jobs, name
)

if traverse_pages < 256 / self.JOBS:
# Threaded traverse will require making at least 255 more requests
# to the remote, so for small enough remotes, fetching the entire
# list at once will require fewer requests (but also take into
# account that this must be done sequentially rather than in
# parallel)
logger.debug(
"Querying {} checksums via default traverse".format(
len(checksums)
)
)
return list(checksums & set(self.all()))

logger.debug(
"Querying {} checksums via threaded traverse".format(
len(checksums)
)
)

with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
total=256,
unit="dir",
) as pbar:

def list_with_update(prefix):
ret = map(
self.path_to_checksum,
list(self.list_cache_paths(prefix=prefix)),
)
pbar.update_desc(
"Querying cache in {}/{}".format(self.path_info, prefix)
)
return ret

with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
in_remote = executor.map(
list_with_update,
["{:02x}".format(i) for i in range(1, 256)],
)
remote_checksums.update(
itertools.chain.from_iterable(in_remote)
)
return list(checksums & remote_checksums)

def _cache_exists_no_traverse(self, checksums, jobs=None, name=None):
logger.debug(
"Querying {} checksums via no_traverse".format(len(checksums))
)
with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
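The per-remote changes that follow all accept the same new argument behind this heuristic: `list_cache_paths` takes an optional two-character checksum prefix, and the cloud remotes narrow their listing to the matching parent cache directory (the local and SSH remotes accept the argument but still walk the whole cache). A minimal, hypothetical filesystem-style remote, not part of this commit and not the real dvc base class, might satisfy that contract like this:

import os


class HypotheticalRemote:
    """Sketch of the list_cache_paths(prefix=...) contract; illustrative only."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir  # e.g. ".dvc/cache"

    def list_cache_paths(self, prefix=None):
        # With prefix "00", list only <cache_dir>/00/; with no prefix,
        # walk every parent cache directory.
        if prefix:
            roots = [os.path.join(self.cache_dir, prefix[:2])]
        else:
            roots = [
                os.path.join(self.cache_dir, name)
                for name in sorted(os.listdir(self.cache_dir))
            ]
        for root in roots:
            if not os.path.isdir(root):
                continue  # a missing parent dir simply yields nothing
            for entry in sorted(os.listdir(root)):
                yield os.path.join(root, entry)

With a remote like this, the size estimate used above reduces to 256 * len(list(remote.list_cache_paths(prefix="00"))).
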
dvc/remote/gdrive.py: 10 additions & 2 deletions
@@ -75,6 +75,8 @@ class RemoteGDrive(RemoteBASE):
REQUIRES = {"pydrive2": "pydrive2"}
DEFAULT_NO_TRAVERSE = False
DEFAULT_VERIFY = True
# Always prefer traverse for GDrive since API usage quotas are a concern.
TRAVERSE_WEIGHT_MULTIPLIER = 1

GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA"
DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json"
@@ -414,12 +416,18 @@ def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
if not self.cache["ids"]:
return

if prefix:
dir_ids = self.cache["dirs"].get(prefix[:2])
if not dir_ids:
return
else:
dir_ids = self.cache["ids"]
parents_query = " or ".join(
"'{}' in parents".format(dir_id) for dir_id in self.cache["ids"]
"'{}' in parents".format(dir_id) for dir_id in dir_ids
)
query = "({}) and trashed=false".format(parents_query)

dvc/remote/gs.py: 8 additions & 3 deletions
@@ -69,6 +69,7 @@ class RemoteGS(RemoteBASE):
scheme = Schemes.GS
path_cls = CloudURLInfo
REQUIRES = {"google-cloud-storage": "google.cloud.storage"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "md5"

def __init__(self, repo, config):
@@ -126,14 +127,18 @@ def remove(self, path_info):

blob.delete()

def _list_paths(self, path_info, max_items=None):
def _list_paths(self, path_info, max_items=None, prefix=None):
if prefix:
prefix = "/".join([path_info.path, prefix[:2], prefix[2:]])
else:
prefix = path_info.path
for blob in self.gs.bucket(path_info.bucket).list_blobs(
prefix=prefix, max_results=max_items
):
yield blob.name

def list_cache_paths(self):
return self._list_paths(self.path_info)
def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)

def walk_files(self, path_info):
for fname in self._list_paths(path_info / ""):
dvc/remote/hdfs.py: 18 additions & 7 deletions
@@ -21,6 +21,7 @@ class RemoteHDFS(RemoteBASE):
REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
PARAM_CHECKSUM = "checksum"
REQUIRES = {"pyarrow": "pyarrow"}
DEFAULT_NO_TRAVERSE = False

def __init__(self, repo, config):
super().__init__(repo, config)
@@ -152,16 +153,26 @@ def open(self, path_info, mode="r", encoding=None):
raise FileNotFoundError(*e.args)
raise

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
if not self.exists(self.path_info):
return

dirs = deque([self.path_info.path])
if prefix:
root = "/".join([self.path_info.path, prefix[:2]])
else:
root = self.path_info.path
dirs = deque([root])

with self.hdfs(self.path_info) as hdfs:
while dirs:
for entry in hdfs.ls(dirs.pop(), detail=True):
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
yield urlparse(entry["name"]).path
try:
for entry in hdfs.ls(dirs.pop(), detail=True):
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
yield urlparse(entry["name"]).path
except IOError as e:
# When searching for a specific prefix pyarrow raises an
# exception if the specified cache dir does not exist
if not prefix:
raise e
dvc/remote/local.py: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def cache_dir(self, value):
def supported(cls, config):
return True

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
assert self.path_info is not None
return walk_files(self.path_info)

dvc/remote/oss.py: 7 additions & 2 deletions
@@ -35,6 +35,7 @@ class RemoteOSS(RemoteBASE):
scheme = Schemes.OSS
path_cls = CloudURLInfo
REQUIRES = {"oss2": "oss2"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5

@@ -95,8 +96,12 @@ def _list_paths(self, prefix):
for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
yield blob.key

def list_cache_paths(self):
return self._list_paths(self.path_info.path)
def list_cache_paths(self, prefix=None):
if prefix:
prefix = "/".join([self.path_info.path, prefix[:2], prefix[2:]])
else:
prefix = self.path_info.path
return self._list_paths(prefix)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
dvc/remote/s3.py: 9 additions & 5 deletions
@@ -21,6 +21,7 @@ class RemoteS3(RemoteBASE):
scheme = Schemes.S3
path_cls = CloudURLInfo
REQUIRES = {"boto3": "boto3"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"

def __init__(self, repo, config):
@@ -190,24 +191,27 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)

def _list_objects(self, path_info, max_items=None):
def _list_objects(self, path_info, max_items=None, prefix=None):
""" Read config for list object api, paginate through list objects."""
kwargs = {
"Bucket": path_info.bucket,
"Prefix": path_info.path,
"PaginationConfig": {"MaxItems": max_items},
}
if prefix:
kwargs["Prefix"] = "/".join([path_info.path, prefix[:2]])
paginator = self.s3.get_paginator(self.list_objects_api)
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", ())

def _list_paths(self, path_info, max_items=None):
def _list_paths(self, path_info, max_items=None, prefix=None):
return (
item["Key"] for item in self._list_objects(path_info, max_items)
item["Key"]
for item in self._list_objects(path_info, max_items, prefix)
)

def list_cache_paths(self):
return self._list_paths(self.path_info)
def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)

def isfile(self, path_info):
from botocore.exceptions import ClientError
dvc/remote/ssh/__init__.py: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ def open(self, path_info, mode="r", encoding=None):
else:
yield io.TextIOWrapper(fd, encoding=encoding)

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
with self.ssh(self.path_info) as ssh:
# If we simply return an iterator then with above closes instantly
yield from ssh.walk_files(self.path_info.path)
