gdrive: method signatures cleanup, more safeguards (#3548)
shcheklein authored Mar 30, 2020
1 parent 0c2e4d0 commit f6077d0
Showing 3 changed files with 129 additions and 131 deletions.
dvc/remote/gdrive.py: 126 additions & 128 deletions
@@ -38,7 +38,7 @@ def __init__(self, path):
         )
 
 
-def gdrive_retry(func):
+def _gdrive_retry(func):
     def should_retry(exc):
         from pydrive2.files import ApiRequestError
 
@@ -48,12 +48,12 @@ def should_retry(exc):
         retry_codes = [403, 500, 502, 503, 504]
         result = exc.error.get("code", 0) in retry_codes
         if result:
-            logger.debug("Retry GDrive API call failed with {}.".format(exc))
+            logger.debug("Retrying GDrive API call, error: {}.".format(exc))
         return result
 
-    # 15 tries, start at 0.5s, multiply by golden ratio, cap at 20s
+    # 16 tries, start at 0.5s, multiply by golden ratio, cap at 20s
     return retry(
-        15,
+        16,
        timeout=lambda a: min(0.5 * 1.618 ** a, 20),
        filter_errors=should_retry,
    )(func)
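
The new schedule is easiest to see numerically. A minimal sketch (not part of the commit) of the delay curve handed to funcy's `retry`, assuming funcy sleeps between consecutive tries:

```python
# Reproduce the backoff passed to retry(): 16 tries, first delay 0.5s,
# each delay multiplied by the golden ratio, capped at 20s.
delays = [min(0.5 * 1.618 ** a, 20) for a in range(16)]

print([round(d, 2) for d in delays])
# [0.5, 0.81, 1.31, 2.12, 3.43, 5.54, 8.97, 14.52, 20, 20, ...]

# At most 15 waits happen between 16 tries:
print(round(sum(delays[:15]), 1))  # ~177.2s worst case before giving up
```

With the extra try, a persistently failing call now backs off for roughly three minutes in total.
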
@@ -171,7 +171,7 @@ def _validate_config(self):
 
     @wrap_prop(threading.RLock())
     @cached_property
-    def drive(self):
+    def _drive(self):
         from pydrive2.auth import RefreshError
         from pydrive2.auth import GoogleAuth
         from pydrive2.drive import GoogleDrive
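
For context, `wrap_prop(threading.RLock())` over `cached_property` makes the lazily built client safe to touch from several threads: the first access builds it under the lock, later accesses hit the cache. A rough sketch of the pattern with a hypothetical stand-in class (not dvc code):

```python
import threading

from funcy import cached_property, wrap_prop

class Client:  # hypothetical stand-in for the remote class
    @wrap_prop(threading.RLock())  # serialize the first, lazy build
    @cached_property
    def conn(self):
        print("built once")
        return object()

c = Client()
threads = [threading.Thread(target=lambda: c.conn) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# "built once" appears a single time regardless of thread count
```
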
@@ -237,46 +237,58 @@ def drive(self):
 
     @wrap_prop(threading.RLock())
     @cached_property
-    def cache(self):
-        cache = {"dirs": defaultdict(list), "ids": {}}
+    def _ids_cache(self):
+        cache = {
+            "dirs": defaultdict(list),
+            "ids": {},
+            "root_id": self._get_item_id(self.path_info, use_cache=False),
+        }
 
-        cache["root_id"] = self._get_remote_id(self.path_info)
-        cache["dirs"][self.path_info.path] = [cache["root_id"]]
-        self._cache_path(self.path_info.path, cache["root_id"], cache)
+        self._cache_path_id(self.path_info.path, cache["root_id"], cache)
 
-        for item in self.gdrive_list_item(
+        for item in self._gdrive_list(
             "'{}' in parents and trashed=false".format(cache["root_id"])
         ):
-            remote_path = (self.path_info / item["title"]).path
-            self._cache_path(remote_path, item["id"], cache)
+            item_path = (self.path_info / item["title"]).path
+            self._cache_path_id(item_path, item["id"], cache)
 
         return cache
 
-    def _cache_path(self, remote_path, remote_id, cache=None):
-        cache = cache or self.cache
-        cache["dirs"][remote_path].append(remote_id)
-        cache["ids"][remote_id] = remote_path
+    def _cache_path_id(self, path, item_id, cache=None):
+        cache = cache or self._ids_cache
+        cache["dirs"][path].append(item_id)
+        cache["ids"][item_id] = path
 
     @cached_property
-    def list_params(self):
+    def _list_params(self):
         params = {"corpora": "default"}
         if self._bucket != "root" and self._bucket != "appDataFolder":
-            drive_id = self._get_remote_drive_id(self._bucket)
+            drive_id = self._gdrive_shared_drive_id(self._bucket)
             if drive_id:
                 params["driveId"] = drive_id
                 params["corpora"] = "drive"
         return params
 
-    @gdrive_retry
-    def gdrive_upload_file(
+    @_gdrive_retry
+    def _gdrive_shared_drive_id(self, item_id):
+        param = {"id": item_id}
+        # it does not create a file on the remote
+        item = self._drive.CreateFile(param)
+        # ID of the shared drive the item resides in.
+        # Only populated for items in shared drives.
+        item.FetchMetadata("driveId")
+        return item.get("driveId", None)
+
+    @_gdrive_retry
+    def _gdrive_upload_file(
         self,
         parent_id,
         title,
         no_progress_bar=False,
         from_file="",
         progress_name="",
     ):
-        item = self.drive.CreateFile(
+        item = self._drive.CreateFile(
             {"title": title, "parents": [{"id": parent_id}]}
         )
 
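The renamed `_ids_cache` is a plain dict of dicts; the one subtlety is that `dirs` maps a path to a *list* of IDs, because Google Drive happily stores several items with the same title under one parent. A toy illustration (hypothetical IDs, no API calls):

```python
from collections import defaultdict

cache = {"dirs": defaultdict(list), "ids": {}, "root_id": "id-root"}

def cache_path_id(path, item_id):
    # mirrors _cache_path_id: one path -> many ids, one id -> one path
    cache["dirs"][path].append(item_id)
    cache["ids"][item_id] = path

cache_path_id("data", "id-1")
cache_path_id("data", "id-2")  # same title uploaded twice on the remote

print(cache["dirs"]["data"])   # ['id-1', 'id-2']
print(cache["ids"]["id-2"])    # data
```
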
@@ -296,71 +308,37 @@ def gdrive_upload_file(
         item.Upload()
         return item
 
-    @gdrive_retry
-    def gdrive_download_file(
-        self, file_id, to_file, progress_name, no_progress_bar
+    @_gdrive_retry
+    def _gdrive_download_file(
+        self, item_id, to_file, progress_desc, no_progress_bar
     ):
-        param = {"id": file_id}
+        param = {"id": item_id}
         # it does not create a file on the remote
-        gdrive_file = self.drive.CreateFile(param)
+        gdrive_file = self._drive.CreateFile(param)
         bar_format = (
             "Downloading {desc:{ncols_desc}.{ncols_desc}}... "
             + Tqdm.format_sizeof(int(gdrive_file["fileSize"]), "B", 1024)
         )
         with Tqdm(
-            bar_format=bar_format, desc=progress_name, disable=no_progress_bar
+            bar_format=bar_format, desc=progress_desc, disable=no_progress_bar
         ):
             gdrive_file.GetContentFile(to_file)
 
-    def gdrive_list_item(self, query):
-        param = {"q": query, "maxResults": 1000}
-        param.update(self.list_params)
-
-        file_list = self.drive.ListFile(param)
-
-        # Isolate and decorate fetching of remote drive items in pages
-        get_list = gdrive_retry(lambda: next(file_list, None))
-
-        # Fetch pages until None is received, lazily flatten the thing
-        return cat(iter(get_list, None))
-
-    @wrap_with(threading.RLock())
-    def gdrive_create_dir(self, parent_id, title, remote_path):
-        cached = self.cache["dirs"].get(remote_path)
-        if cached:
-            return cached[0]
-
-        item = self._create_remote_dir(parent_id, title)
-
-        if parent_id == self.cache["root_id"]:
-            self._cache_path(remote_path, item["id"])
-
-        return item["id"]
-
-    @gdrive_retry
-    def _create_remote_dir(self, parent_id, title):
-        parent = {"id": parent_id}
-        item = self.drive.CreateFile(
-            {"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE}
-        )
-        item.Upload()
-        return item
-
-    @gdrive_retry
-    def _delete_remote_file(self, remote_id):
+    @_gdrive_retry
+    def _gdrive_delete_file(self, item_id):
         from pydrive2.files import ApiRequestError
 
-        param = {"id": remote_id}
+        param = {"id": item_id}
         # it does not create a file on the remote
-        item = self.drive.CreateFile(param)
+        item = self._drive.CreateFile(param)
 
         try:
             item.Trash() if self._trash_only else item.Delete()
         except ApiRequestError as exc:
             http_error_code = exc.error.get("code", 0)
             if (
                 http_error_code == 403
-                and self.list_params["corpora"] == "drive"
+                and self._list_params["corpora"] == "drive"
                 and _location(exc) == "file.permissions"
             ):
                 raise DvcException(
@@ -378,112 +356,132 @@ def _delete_remote_file(self, remote_id):
                 ) from exc
             raise
 
-    @gdrive_retry
-    def _get_remote_item(self, name, parents_ids):
-        if not parents_ids:
-            return None
-        query = "({})".format(
-            " or ".join(
-                "'{}' in parents".format(parent_id)
-                for parent_id in parents_ids
-            )
+    def _gdrive_list(self, query):
+        param = {"q": query, "maxResults": 1000}
+        param.update(self._list_params)
+        file_list = self._drive.ListFile(param)
+
+        # Isolate and decorate fetching of remote drive items in pages.
+        get_list = _gdrive_retry(lambda: next(file_list, None))
+
+        # Fetch pages until None is received, lazily flatten the thing.
+        return cat(iter(get_list, None))
+
+    @_gdrive_retry
+    def _gdrive_create_dir(self, parent_id, title):
+        parent = {"id": parent_id}
+        item = self._drive.CreateFile(
+            {"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE}
         )
+        item.Upload()
+        return item
 
-        query += " and trashed=false and title='{}'".format(name)
+    @wrap_with(threading.RLock())
+    def _create_dir(self, parent_id, title, remote_path):
+        cached = self._ids_cache["dirs"].get(remote_path)
+        if cached:
+            return cached[0]
 
-        # Remote might contain items with duplicated path (titles).
-        # We thus limit number of items.
-        param = {"q": query, "maxResults": 1}
-        param.update(self.list_params)
+        item = self._gdrive_create_dir(parent_id, title)
 
-        # Limit found remote items count to 1 in response
-        item_list = self.drive.ListFile(param).GetList()
-        return next(iter(item_list), None)
+        if parent_id == self._ids_cache["root_id"]:
+            self._cache_path_id(remote_path, item["id"])
 
-    @gdrive_retry
-    def _get_remote_drive_id(self, remote_id):
-        param = {"id": remote_id}
-        # it does not create a file on the remote
-        item = self.drive.CreateFile(param)
-        item.FetchMetadata("driveId")
-        return item.get("driveId", None)
+        return item["id"]
 
-    def _get_cached_remote_ids(self, path):
+    def _get_remote_item_ids(self, parent_ids, title):
+        if not parent_ids:
+            return None
+        query = "trashed=false and ({})".format(
+            " or ".join(
+                "'{}' in parents".format(parent_id) for parent_id in parent_ids
+            )
+        )
+        query += " and title='{}'".format(title.replace("'", "\\'"))
+
+        # GDrive list API is case insensitive, we need to compare
+        # all results and pick the ones with the right title
+        return [
+            item["id"]
+            for item in self._gdrive_list(query)
+            if item["title"] == title
+        ]
+
+    def _get_cached_item_ids(self, path, use_cache):
         if not path:
             return [self._bucket]
-        if "cache" in self.__dict__:
-            return self.cache["dirs"].get(path, [])
+        if use_cache:
+            return self._ids_cache["dirs"].get(path, [])
         return []
 
-    def _path_to_remote_ids(self, path, create):
-        remote_ids = self._get_cached_remote_ids(path)
-        if remote_ids:
-            return remote_ids
 
-        parent_path, part = posixpath.split(path)
-        parent_ids = self._path_to_remote_ids(parent_path, create)
-        item = self._get_remote_item(part, parent_ids)
+    def _path_to_item_ids(self, path, create, use_cache):
+        item_ids = self._get_cached_item_ids(path, use_cache)
+        if item_ids:
+            return item_ids
 
-        if not item:
-            return (
-                [self.gdrive_create_dir(parent_ids[0], part, path)]
-                if create
-                else []
-            )
+        parent_path, title = posixpath.split(path)
+        parent_ids = self._path_to_item_ids(parent_path, create, use_cache)
+        item_ids = self._get_remote_item_ids(parent_ids, title)
+        if item_ids:
+            return item_ids
 
-        return [item["id"]]
+        return (
+            [self._create_dir(min(parent_ids), title, path)] if create else []
+        )
 
-    def _get_remote_id(self, path_info, create=False):
+    def _get_item_id(self, path_info, create=False, use_cache=True):
         assert path_info.bucket == self._bucket
 
-        remote_ids = self._path_to_remote_ids(path_info.path, create)
-        if not remote_ids:
-            raise GDrivePathNotFound(path_info)
+        item_ids = self._path_to_item_ids(path_info.path, create, use_cache)
+        if item_ids:
+            return min(item_ids)
 
-        return remote_ids[0]
+        raise GDrivePathNotFound(path_info)
 
     def exists(self, path_info):
         try:
-            self._get_remote_id(path_info)
+            self._get_item_id(path_info)
         except GDrivePathNotFound:
             return False
         else:
             return True
 
-    def _upload(self, from_file, to_info, name, no_progress_bar):
+    def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         dirname = to_info.parent
         assert dirname
-        parent_id = self._get_remote_id(dirname, True)
+        parent_id = self._get_item_id(dirname, True)
 
-        self.gdrive_upload_file(
+        self._gdrive_upload_file(
             parent_id, to_info.name, no_progress_bar, from_file, name
         )
 
-    def _download(self, from_info, to_file, name, no_progress_bar):
-        file_id = self._get_remote_id(from_info)
-        self.gdrive_download_file(file_id, to_file, name, no_progress_bar)
+    def _download(self, from_info, to_file, name=None, no_progress_bar=False):
+        item_id = self._get_item_id(from_info)
+        self._gdrive_download_file(item_id, to_file, name, no_progress_bar)
 
     def list_cache_paths(self, prefix=None, progress_callback=None):
-        if not self.cache["ids"]:
+        if not self._ids_cache["ids"]:
             return
 
         if prefix:
-            dir_ids = self.cache["dirs"].get(prefix[:2])
+            dir_ids = self._ids_cache["dirs"].get(prefix[:2])
             if not dir_ids:
                 return
         else:
-            dir_ids = self.cache["ids"]
+            dir_ids = self._ids_cache["ids"]
         parents_query = " or ".join(
             "'{}' in parents".format(dir_id) for dir_id in dir_ids
         )
         query = "({}) and trashed=false".format(parents_query)
 
-        for item in self.gdrive_list_item(query):
+        for item in self._gdrive_list(query):
             if progress_callback:
                 progress_callback()
             parent_id = item["parents"][0]["id"]
-            yield posixpath.join(self.cache["ids"][parent_id], item["title"])
+            yield posixpath.join(
+                self._ids_cache["ids"][parent_id], item["title"]
+            )
 
     def remove(self, path_info):
-        remote_id = self._get_remote_id(path_info)
-        self._delete_remote_file(remote_id)
+        item_id = self._get_item_id(path_info)
+        self._gdrive_delete_file(item_id)
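
The paging in the relocated `_gdrive_list` leans on two small tricks: the two-argument form of `iter()`, which keeps calling a function until it returns the sentinel, and funcy's lazy `cat`, which chains the pages into one flat iterable. A stripped-down sketch with fake pages instead of `ListFile` (names here are illustrative only):

```python
from funcy import cat

pages = iter([["chunk-a", "chunk-b"], ["chunk-c"]])  # stand-in for API pages

def get_list():
    # next page, or None once the listing is exhausted
    return next(pages, None)

# iter(callable, sentinel) calls get_list() until it returns None;
# cat lazily flattens the pages into a single stream of items.
items = cat(iter(get_list, None))
print(list(items))  # ['chunk-a', 'chunk-b', 'chunk-c']
```
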
tests/func/test_data_cloud.py: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ def setup_gdrive_cloud(remote_url, dvc):

     dvc.config = config
     remote = DataCloud(dvc).get_remote()
-    remote._create_remote_dir("root", remote.path_info.path)
+    remote._gdrive_create_dir("root", remote.path_info.path)
 
 
 class TestRemoteGDrive(GDrive, TestDataCloudBase):
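
The "more safeguards" of the commit title live mostly in `_get_remote_item_ids` and `_get_item_id`: single quotes in titles are escaped before being interpolated into the query, results are re-checked client-side because the Drive list API matches titles case-insensitively, and `min()` makes the choice among duplicated items deterministic. A toy run of that logic (made-up titles and IDs, no API calls):

```python
title = "it's data"
escaped = title.replace("'", "\\'")  # what goes into the Drive query
print("title='{}'".format(escaped))  # title='it\'s data'

# Pretend listing: one case-mismatched item, two true duplicates.
listed = [
    {"id": "id-9", "title": "IT'S DATA"},  # dropped by the exact check
    {"id": "id-2", "title": "it's data"},
    {"id": "id-1", "title": "it's data"},
]
item_ids = [it["id"] for it in listed if it["title"] == title]

# min() picks the same duplicate on every run instead of relying on
# whatever order the API happened to return.
print(min(item_ids) if item_ids else None)  # id-1
```
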