remote: Optimize traverse/no_traverse behavior #3501

Merged (11 commits) on Mar 25, 2020
12 changes: 10 additions & 2 deletions dvc/remote/azure.py
@@ -1,5 +1,6 @@
import logging
import os
import posixpath
import re
from datetime import datetime, timedelta
from urllib.parse import urlparse
@@ -26,6 +27,7 @@ class RemoteAZURE(RemoteBASE):
r"(?P<connection_string>.+)?)?)$"
)
REQUIRES = {"azure-storage-blob": "azure.storage.blob"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5

@@ -103,8 +105,14 @@ def _list_paths(self, bucket, prefix):

next_marker = blobs.next_marker

def list_cache_paths(self):
return self._list_paths(self.path_info.bucket, self.path_info.path)
def list_cache_paths(self, prefix=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(self.path_info.bucket, prefix)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
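For context on the prefix[:2] / prefix[2:] split used above (and repeated in the gs, oss, and s3 remotes below): cache objects live under a two-level layout, <cache>/<checksum[:2]>/<checksum[2:]>. A minimal sketch of that mapping, with an illustrative helper name:

import posixpath

def cache_prefix_path(base_path, prefix=None):
    # A traverse prefix such as "00a" maps to "<base>/00/a", matching
    # the two-level cache layout <base>/<checksum[:2]>/<checksum[2:]>.
    if prefix:
        return posixpath.join(base_path, prefix[:2], prefix[2:])
    return base_path

assert cache_prefix_path("cache", "00a") == "cache/00/a"
assert cache_prefix_path("cache") == "cache"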
122 changes: 118 additions & 4 deletions dvc/remote/base.py
@@ -82,6 +82,10 @@ class RemoteBASE(object):
DEFAULT_CACHE_TYPES = ["copy"]
DEFAULT_NO_TRAVERSE = True
DEFAULT_VERIFY = False
LIST_OBJECT_PAGE_SIZE = 1000
TRAVERSE_WEIGHT_MULTIPLIER = 20
TRAVERSE_PREFIX_LEN = 3
TRAVERSE_THRESHOLD_SIZE = 500000

CACHE_MODE = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
@@ -686,7 +690,7 @@ def path_to_checksum(self, path):
def checksum_to_path_info(self, checksum):
return self.path_info / checksum[0:2] / checksum[2:]

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
raise NotImplementedError

def all(self):
@@ -804,7 +808,9 @@ def cache_exists(self, checksums, jobs=None, name=None):

- Traverse: Get a list of all the files in the remote
(traversing the cache directory) and compare it with
the given checksums.
the given checksums. Cache parent directories will
be retrieved in parallel threads, and a progress bar
will be displayed.

- No traverse: For each given checksum, run the `exists`
method and filter the checksums that aren't on the remote.
@@ -817,12 +823,120 @@ def cache_exists(self, checksums, jobs=None, name=None):
check if a particular file exists much more quickly, use their own
implementation of cache_exists (see ssh, local).

Which method to use will be automatically determined after estimating
the size of the remote cache, and comparing the estimated size with
len(checksums). To estimate the size of the remote cache, we fetch
a small subset of cache entries (i.e. entries starting with "00...").
Based on the number of entries in that subset, the size of the full
cache can be estimated, since the cache is evenly distributed according
to checksum.

Returns:
A list with checksums that were found in the remote
"""
if not self.no_traverse:
return list(set(checksums) & set(self.all()))
# Remotes which do not use traverse prefix should override
# cache_exists() (see ssh, local)
assert self.TRAVERSE_PREFIX_LEN >= 2

if self.no_traverse:
return self._cache_exists_no_traverse(checksums, jobs, name)

# Fetch cache entries beginning with "00..." prefix for estimating the
# size of entire remote cache
checksums = frozenset(checksums)
prefix = "0" * self.TRAVERSE_PREFIX_LEN
total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN)
remote_checksums = set(
map(self.path_to_checksum, self.list_cache_paths(prefix=prefix))
)
if remote_checksums:
remote_size = total_prefixes * len(remote_checksums)
else:
remote_size = total_prefixes
logger.debug("Estimated remote size: {} files".format(remote_size))

traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE
# For sufficiently large remotes, traverse must be weighted to account
# for performance overhead from large lists/sets.
# From testing with S3, for remotes with 1M+ files, no_traverse is
# faster until len(checksums) is at least 10k~100k
if remote_size > self.TRAVERSE_THRESHOLD_SIZE:
traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER
else:
traverse_weight = traverse_pages
if len(checksums) < traverse_weight:
logger.debug(
"Large remote ({} checksums < {} traverse weight), "
"using no_traverse for remaining checksums".format(
len(checksums), traverse_weight
)
)
return list(
checksums & remote_checksums
) + self._cache_exists_no_traverse(
checksums - remote_checksums, jobs, name
)

if traverse_pages < 256 / self.JOBS:
# Threaded traverse will require making at least 255 more requests
# to the remote, so for small enough remotes, fetching the entire
# list at once will require fewer requests (but also take into
# account that this must be done sequentially rather than in
# parallel)
logger.debug(
"Querying {} checksums via default traverse".format(
len(checksums)
)
)
return list(checksums & set(self.all()))

return self._cache_exists_traverse(
checksums, remote_checksums, jobs, name
)

def _cache_exists_traverse(
self, checksums, remote_checksums, jobs=None, name=None
):
logger.debug(
"Querying {} checksums via threaded traverse".format(
len(checksums)
)
)

traverse_prefixes = ["{:02x}".format(i) for i in range(1, 256)]
if self.TRAVERSE_PREFIX_LEN > 2:
traverse_prefixes += [
"{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN)
for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2))
]
with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
total=len(traverse_prefixes),
unit="dir",
) as pbar:

def list_with_update(prefix):
ret = map(
self.path_to_checksum,
list(self.list_cache_paths(prefix=prefix)),
)
pbar.update_desc(
"Querying cache in {}/{}".format(self.path_info, prefix)
)
return ret

with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
in_remote = executor.map(list_with_update, traverse_prefixes)
remote_checksums.update(
itertools.chain.from_iterable(in_remote)
)
return list(checksums & remote_checksums)

def _cache_exists_no_traverse(self, checksums, jobs=None, name=None):
logger.debug(
"Querying {} checksums via no_traverse".format(len(checksums))
)
with Tqdm(
desc="Querying "
+ ("cache in " + name if name else "remote cache"),
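To make the method-selection logic in cache_exists concrete, here is a small self-contained sketch of the arithmetic (class constants copied from the diff; JOBS and the sample counts are illustrative assumptions, since the real value comes from RemoteBASE):

LIST_OBJECT_PAGE_SIZE = 1000
TRAVERSE_WEIGHT_MULTIPLIER = 20
TRAVERSE_PREFIX_LEN = 3
TRAVERSE_THRESHOLD_SIZE = 500000
JOBS = 64  # illustrative stand-in for RemoteBASE.JOBS

def choose_method(num_checksums, entries_under_000):
    # One sampled prefix ("000") out of 16**3 = 4096 equally likely
    # prefixes, so scale the sample count up by the prefix space.
    total_prefixes = 16 ** TRAVERSE_PREFIX_LEN
    remote_size = total_prefixes * max(entries_under_000, 1)
    traverse_pages = remote_size / LIST_OBJECT_PAGE_SIZE
    traverse_weight = traverse_pages
    if remote_size > TRAVERSE_THRESHOLD_SIZE:
        traverse_weight *= TRAVERSE_WEIGHT_MULTIPLIER
    if num_checksums < traverse_weight:
        return "no_traverse"
    if traverse_pages < 256 / JOBS:
        return "default traverse"
    return "threaded traverse"

# 250 entries under "000" suggest ~1M files, i.e. ~1024 list pages
# weighted to 20480, so 5k individual queries win:
print(choose_method(5000, 250))    # no_traverse
print(choose_method(100000, 250))  # threaded traverse

Note that the sampled prefix "000" together with the generated traverse prefixes "01".."ff" and "001".."00f" exactly partitions the checksum space, so the threaded traverse never lists the same entry twice.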
13 changes: 11 additions & 2 deletions dvc/remote/gdrive.py
@@ -75,6 +75,9 @@ class RemoteGDrive(RemoteBASE):
REQUIRES = {"pydrive2": "pydrive2"}
DEFAULT_NO_TRAVERSE = False
DEFAULT_VERIFY = True
# Always prefer traverse for GDrive since API usage quotas are a concern.
TRAVERSE_WEIGHT_MULTIPLIER = 1
TRAVERSE_PREFIX_LEN = 2

GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA"
DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json"
@@ -414,12 +417,18 @@ def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
if not self.cache["ids"]:
return

if prefix:
dir_ids = self.cache["dirs"].get(prefix[:2])
if not dir_ids:
return
else:
dir_ids = self.cache["ids"]
parents_query = " or ".join(
"'{}' in parents".format(dir_id) for dir_id in self.cache["ids"]
"'{}' in parents".format(dir_id) for dir_id in dir_ids
)
query = "({}) and trashed=false".format(parents_query)

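GDrive keeps a locally cached index mapping the two-character cache directory names to their Drive folder IDs, which is why TRAVERSE_PREFIX_LEN = 2 suffices here: a prefix resolves to its parent folder without an extra API call. A rough sketch of that lookup, assuming a cache dict shaped like the one the remote builds (the IDs are hypothetical):

# Hypothetical shape of the remote's cached directory index:
cache = {
    "ids": ["id_root", "id_00", "id_ff"],
    "dirs": {"00": ["id_00"], "ff": ["id_ff"]},
}

def dirs_for_prefix(prefix=None):
    # With a prefix, query only the matching cache directory;
    # otherwise query every known folder ID.
    if prefix:
        return cache["dirs"].get(prefix[:2]) or []
    return cache["ids"]

parents_query = " or ".join(
    "'{}' in parents".format(dir_id) for dir_id in dirs_for_prefix("00")
)
print("({}) and trashed=false".format(parents_query))
# ('id_00' in parents) and trashed=false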
12 changes: 9 additions & 3 deletions dvc/remote/gs.py
@@ -3,6 +3,7 @@
from functools import wraps
import io
import os.path
import posixpath
import threading

from funcy import cached_property, wrap_prop
@@ -69,6 +70,7 @@ class RemoteGS(RemoteBASE):
scheme = Schemes.GS
path_cls = CloudURLInfo
REQUIRES = {"google-cloud-storage": "google.cloud.storage"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "md5"

def __init__(self, repo, config):
@@ -126,14 +128,18 @@ def remove(self, path_info):

blob.delete()

def _list_paths(self, path_info, max_items=None):
def _list_paths(self, path_info, max_items=None, prefix=None):
if prefix:
prefix = posixpath.join(path_info.path, prefix[:2], prefix[2:])
else:
prefix = path_info.path
for blob in self.gs.bucket(path_info.bucket).list_blobs(
prefix=prefix, max_results=max_items
):
yield blob.name

def list_cache_paths(self):
return self._list_paths(self.path_info)
def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)

def walk_files(self, path_info):
for fname in self._list_paths(path_info / ""):
26 changes: 19 additions & 7 deletions dvc/remote/hdfs.py
@@ -21,6 +21,8 @@ class RemoteHDFS(RemoteBASE):
REGEX = r"^hdfs://((?P<user>.*)@)?.*$"
PARAM_CHECKSUM = "checksum"
REQUIRES = {"pyarrow": "pyarrow"}
DEFAULT_NO_TRAVERSE = False
TRAVERSE_PREFIX_LEN = 2

def __init__(self, repo, config):
super().__init__(repo, config)
@@ -152,16 +154,26 @@ def open(self, path_info, mode="r", encoding=None):
raise FileNotFoundError(*e.args)
raise

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
if not self.exists(self.path_info):
return

dirs = deque([self.path_info.path])
if prefix:
root = posixpath.join(self.path_info.path, prefix[:2])
else:
root = self.path_info.path
dirs = deque([root])

with self.hdfs(self.path_info) as hdfs:
while dirs:
for entry in hdfs.ls(dirs.pop(), detail=True):
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
yield urlparse(entry["name"]).path
try:
for entry in hdfs.ls(dirs.pop(), detail=True):
if entry["kind"] == "directory":
dirs.append(urlparse(entry["name"]).path)
elif entry["kind"] == "file":
yield urlparse(entry["name"]).path
except IOError as e:
# When searching for a specific prefix pyarrow raises an
# exception if the specified cache dir does not exist
if not prefix:
raise e
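The HDFS traversal above amounts to a breadth-first walk rooted at the prefix directory; a condensed sketch, assuming an hdfs client exposing the ls(path, detail=True) interface shown in the diff:

import posixpath
from collections import deque
from urllib.parse import urlparse

def walk_cache(hdfs, base_path, prefix=None):
    # With a prefix, start from the single two-character cache
    # directory; otherwise walk the whole cache root.
    root = posixpath.join(base_path, prefix[:2]) if prefix else base_path
    dirs = deque([root])
    while dirs:
        try:
            entries = hdfs.ls(dirs.pop(), detail=True)
        except IOError:
            if prefix:
                continue  # the prefix directory may simply not exist
            raise
        for entry in entries:
            if entry["kind"] == "directory":
                dirs.append(urlparse(entry["name"]).path)
            elif entry["kind"] == "file":
                yield urlparse(entry["name"]).path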
3 changes: 2 additions & 1 deletion dvc/remote/local.py
@@ -56,7 +56,8 @@ def cache_dir(self, value):
def supported(cls, config):
return True

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
assert prefix is None
assert self.path_info is not None
return walk_files(self.path_info)

12 changes: 10 additions & 2 deletions dvc/remote/oss.py
@@ -1,5 +1,6 @@
import logging
import os
import posixpath
import threading

from funcy import cached_property, wrap_prop
@@ -35,6 +36,7 @@ class RemoteOSS(RemoteBASE):
scheme = Schemes.OSS
path_cls = CloudURLInfo
REQUIRES = {"oss2": "oss2"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"
COPY_POLL_SECONDS = 5

@@ -95,8 +97,14 @@ def _list_paths(self, prefix):
for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix):
yield blob.key

def list_cache_paths(self):
return self._list_paths(self.path_info.path)
def list_cache_paths(self, prefix=None):
if prefix:
prefix = posixpath.join(
self.path_info.path, prefix[:2], prefix[2:]
)
else:
prefix = self.path_info.path
return self._list_paths(prefix)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
15 changes: 10 additions & 5 deletions dvc/remote/s3.py
@@ -2,6 +2,7 @@

import logging
import os
import posixpath
import threading

from funcy import cached_property, wrap_prop
@@ -21,6 +22,7 @@ class RemoteS3(RemoteBASE):
scheme = Schemes.S3
path_cls = CloudURLInfo
REQUIRES = {"boto3": "boto3"}
DEFAULT_NO_TRAVERSE = False
PARAM_CHECKSUM = "etag"

def __init__(self, repo, config):
@@ -190,24 +192,27 @@ def remove(self, path_info):
logger.debug("Removing {}".format(path_info))
self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)

def _list_objects(self, path_info, max_items=None):
def _list_objects(self, path_info, max_items=None, prefix=None):
""" Read config for list object api, paginate through list objects."""
kwargs = {
"Bucket": path_info.bucket,
"Prefix": path_info.path,
"PaginationConfig": {"MaxItems": max_items},
}
if prefix:
kwargs["Prefix"] = posixpath.join(path_info.path, prefix[:2])
paginator = self.s3.get_paginator(self.list_objects_api)
for page in paginator.paginate(**kwargs):
yield from page.get("Contents", ())

def _list_paths(self, path_info, max_items=None):
def _list_paths(self, path_info, max_items=None, prefix=None):
return (
item["Key"] for item in self._list_objects(path_info, max_items)
item["Key"]
for item in self._list_objects(path_info, max_items, prefix)
)

def list_cache_paths(self):
return self._list_paths(self.path_info)
def list_cache_paths(self, prefix=None):
return self._list_paths(self.path_info, prefix=prefix)

def isfile(self, path_info):
from botocore.exceptions import ClientError
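For reference, the S3 listing builds on boto3's paginator; a minimal standalone sketch of the same pattern (bucket and prefix values are placeholders):

import boto3

def list_keys(bucket, prefix, max_items=None):
    # list_objects_v2 returns at most 1000 keys per page; the paginator
    # follows continuation tokens, stopping once max_items keys are seen.
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(
        Bucket=bucket,
        Prefix=prefix,
        PaginationConfig={"MaxItems": max_items},
    )
    for page in pages:
        for obj in page.get("Contents", ()):
            yield obj["Key"]

# e.g. everything under the "00/" cache directory:
# for key in list_keys("my-bucket", "cache/00"):
#     print(key)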
3 changes: 2 additions & 1 deletion dvc/remote/ssh/__init__.py
@@ -249,7 +249,8 @@ def open(self, path_info, mode="r", encoding=None):
else:
yield io.TextIOWrapper(fd, encoding=encoding)

def list_cache_paths(self):
def list_cache_paths(self, prefix=None):
assert prefix is None
with self.ssh(self.path_info) as ssh:
# If we simply return an iterator then with above closes instantly
yield from ssh.walk_files(self.path_info.path)