[WIP] GDrive code cleanup #3293

Closed · wants to merge 5 commits

Changes from all commits

197 changes: 110 additions & 87 deletions dvc/remote/gdrive.py
@@ -1,8 +1,9 @@
import os
import posixpath
import logging
import threading
import re
import threading
from urllib.parse import urlparse

from funcy import retry, compose, decorator, wrap_with
from funcy.py3 import cat
@@ -23,12 +24,22 @@ class GDriveRetriableError(DvcException):
pass


class GDrivePathNotFound(DvcException):
def __init__(self, path_info):
super().__init__("Google Drive path '{}' not found.".format(path_info))


class GDriveAccessTokenRefreshError(DvcException):
pass
def __init__(self):
super().__init__("Google Drive access token refreshment is failed.")


class GDriveMissedCredentialKeyError(DvcException):
pass
def __init__(self, path):
super().__init__(
"Google Drive user credentials file '{}' "
"misses value for key.".format(path)
)


@decorator
@@ -49,17 +60,32 @@ def _wrap_pydrive_retriable(call):


gdrive_retry = compose(
# 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s
# 15 tries, start at 0.5s, multiply by golden ratio, cap at 20s
retry(
8, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 10)
15, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 20)
),
_wrap_pydrive_retriable,
)
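
For reference, a quick sketch (not part of this PR) of the delays the schedule above produces, assuming funcy's `retry` passes a 0-based attempt number to the `timeout` callable; `backoff` is a made-up helper name:

```python
def backoff(attempt, base=0.5, ratio=1.618, cap=20):
    # same formula as the lambda above: golden-ratio growth, capped
    return min(base * ratio ** attempt, cap)

delays = [backoff(a) for a in range(14)]  # 15 tries -> at most 14 sleeps
# roughly 0.5s, 0.8s, 1.3s, 2.1s, 3.4s, 5.6s, 9.0s, 14.5s, then 20s per retry
```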


class GDriveURLInfo(CloudURLInfo):
def __init__(self, url):
super().__init__(url)

# The GDrive URL host part is case sensitive,
# so we restore it here.
p = urlparse(url)
self.host = p.netloc
assert self.netloc == self.host

# Normalize path. Important since we have a cache (path to ID)
# and don't want to deal with different variations of path in it.
self._spath = re.sub("/{2,}", "/", self._spath.rstrip("/"))
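
The normalization above gives the path→ID cache a single canonical form of every Drive path, while the host/bucket part (`root`, `appDataFolder`, or a folder/drive ID) keeps its original case. A self-contained illustration of the same regex; `normalize` is a hypothetical helper, not DVC code:

```python
import re

def normalize(spath):
    # collapse runs of slashes and drop a trailing one, as above
    return re.sub("/{2,}", "/", spath.rstrip("/"))

assert normalize("/data//images/") == "/data/images"
assert normalize("/data/images") == "/data/images"
```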


class RemoteGDrive(RemoteBASE):
scheme = Schemes.GDRIVE
path_cls = CloudURLInfo
path_cls = GDriveURLInfo
REQUIRES = {"pydrive2": "pydrive2"}
DEFAULT_NO_TRAVERSE = False
DEFAULT_VERIFY = True
@@ -69,32 +95,32 @@ class RemoteGDrive(RemoteBASE):

def __init__(self, repo, config):
super().__init__(repo, config)
self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL])

bucket = re.search(
"{}://(.*)".format(self.scheme),
config[Config.SECTION_REMOTE_URL],
re.IGNORECASE,
)
self.bucket = (
bucket.group(1).split("/")[0] if bucket else self.path_info.bucket
)

url = config[Config.SECTION_REMOTE_URL]
self.path_info = self.path_cls(url)
self.config = config
self.init_drive()

def init_drive(self):
self.client_id = self.config.get(Config.SECTION_GDRIVE_CLIENT_ID, None)
self.client_secret = self.config.get(
if not self.path_info.bucket:
raise DvcException(
"Empty Google Drive URL '{}'. Learn more at "
"{}.".format(
url, format_link("https://man.dvc.org/remote/add")
)
)

self._bucket = self.path_info.bucket
self._client_id = self.config.get(
Config.SECTION_GDRIVE_CLIENT_ID, None
)
self._client_secret = self.config.get(
Config.SECTION_GDRIVE_CLIENT_SECRET, None
)
if not self.client_id or not self.client_secret:
if not self._client_id or not self._client_secret:
raise DvcException(
"Please specify Google Drive's client id and "
"secret in DVC's config. Learn more at "
"secret in DVC config. Learn more at "
"{}.".format(format_link("https://man.dvc.org/remote/add"))
)
self.gdrive_user_credentials_path = (
self._gdrive_user_credentials_path = (
tmp_fname(os.path.join(self.repo.tmp_dir, ""))
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
else self.config.get(
@@ -104,7 +130,8 @@ def init_drive(self):
),
)
)
self.remote_drive_id = None
self._remote_drive_id = None
self._remote_root_id = None

@gdrive_retry
def gdrive_upload_file(
@@ -139,7 +166,7 @@ def gdrive_download_file(
# it does not create a file on the remote
gdrive_file = self.drive.CreateFile(param)
bar_format = (
"Donwloading {desc:{ncols_desc}.{ncols_desc}}... "
"Downloading {desc:{ncols_desc}.{ncols_desc}}... "
+ Tqdm.format_sizeof(int(gdrive_file["fileSize"]), "B", 1024)
)
with Tqdm(
@@ -150,8 +177,8 @@ def gdrive_download_file(
def gdrive_list_item(self, query):
param = {"q": query, "maxResults": 1000, "corpora": self.corpora}

if self.remote_drive_id:
param["driveId"] = self.remote_drive_id
if self._remote_drive_id:
param["driveId"] = self._remote_drive_id

file_list = self.drive.ListFile(param)

@@ -165,11 +192,12 @@ def cache_root_dirs(self):
cached_dirs = {}
cached_ids = {}
for dir1 in self.gdrive_list_item(
"'{}' in parents and trashed=false".format(self.remote_root_id)
"'{}' in parents and trashed=false".format(self._remote_root_id)
):
remote_path = posixpath.join(self.path_info.path, dir1["title"])
cached_dirs.setdefault(remote_path, []).append(dir1["id"])
cached_ids[dir1["id"]] = dir1["title"]

return cached_dirs, cached_ids
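
The two caches built above have these shapes: `cached_dirs` maps a full remote path to the list of folder IDs carrying that title (Drive allows several folders with the same name under one parent), and `cached_ids` is the reverse ID→title lookup. Illustrative, hypothetical values only:

```python
# assuming the remote URL path is "data"
cached_dirs = {
    "data/00": ["1aBcD"],
    "data/a1": ["2eFgH", "3iJkL"],  # duplicated folder title
}
cached_ids = {"1aBcD": "00", "2eFgH": "a1", "3iJkL": "a1"}
```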

@property
@@ -201,16 +229,16 @@ def drive(self):

if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
with open(
self.gdrive_user_credentials_path, "w"
self._gdrive_user_credentials_path, "w"
) as credentials_file:
credentials_file.write(
os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
)

GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"client_id": self._client_id,
"client_secret": self._client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"revoke_uri": "https://oauth2.googleapis.com/revoke",
@@ -220,11 +248,10 @@ def drive(self):
GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
GoogleAuth.DEFAULT_SETTINGS[
"save_credentials_file"
] = self.gdrive_user_credentials_path
] = self._gdrive_user_credentials_path
GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
"https://www.googleapis.com/auth/drive",
# drive.appdata grants access to appDataFolder GDrive directory
"https://www.googleapis.com/auth/drive.appdata",
]

@@ -234,33 +261,26 @@ def drive(self):
try:
gauth.CommandLineAuth()
except RefreshError as exc:
raise GDriveAccessTokenRefreshError(
"Google Drive's access token refreshment is failed"
) from exc
raise GDriveAccessTokenRefreshError from exc
except KeyError as exc:
raise GDriveMissedCredentialKeyError(
"Google Drive's user credentials file '{}' "
"misses value for key '{}'".format(
self.gdrive_user_credentials_path, str(exc)
)
)
# Handle pydrive2.auth.AuthenticationError and others auth failures
self._gdrive_user_credentials_path
) from exc
# Handle pydrive2.auth.AuthenticationError and other auth failures
except Exception as exc:
raise DvcException(
"Google Drive authentication failed"
) from exc
finally:
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
os.remove(self.gdrive_user_credentials_path)
os.remove(self._gdrive_user_credentials_path)

self._gdrive = GoogleDrive(gauth)

if self.bucket != "root" and self.bucket != "appDataFolder":
self.remote_drive_id = self.get_remote_drive_id(self.bucket)
self._corpora = "drive" if self.remote_drive_id else "default"
self.remote_root_id = self.get_remote_id(
self.path_info, create=True
)
if self._bucket != "root" and self._bucket != "appDataFolder":
self._remote_drive_id = self._get_remote_drive_id(self._bucket)
self._corpora = "drive" if self._remote_drive_id else "default"
self._remote_root_id = self._get_remote_id(self.path_info)

self._cached_dirs, self._cached_ids = self.cache_root_dirs()

@@ -302,65 +322,68 @@ def get_remote_item(self, name, parents_ids):
"corpora": self.corpora,
}

if self.remote_drive_id:
param["driveId"] = self.remote_drive_id
if self._remote_drive_id:
param["driveId"] = self._remote_drive_id

# Limit the number of matching remote items returned to 1
item_list = self.drive.ListFile(param).GetList()
return next(iter(item_list), None)

@gdrive_retry
def get_remote_drive_id(self, remote_id):
def _get_remote_drive_id(self, remote_id):
param = {"id": remote_id}
# it does not create a file on the remote
item = self.drive.CreateFile(param)
item.FetchMetadata("driveId")
return item.get("driveId", None)

def resolve_remote_item_from_path(self, path_parts, create):
parents_ids = [self.bucket]
current_path = ""
for path_part in path_parts:
current_path = posixpath.join(current_path, path_part)
remote_ids = self.get_remote_id_from_cache(current_path)
if remote_ids:
parents_ids = remote_ids
continue
item = self.get_remote_item(path_part, parents_ids)
if not item and create:
item = self.create_remote_dir(parents_ids[0], path_part)
elif not item:
return None
parents_ids = [item["id"]]
return item

def get_remote_id_from_cache(self, remote_path):
def _get_known_remote_ids(self, path):
if not path:
return [self._bucket]
if path == self.path_info.path and self._remote_root_id:
return [self._remote_root_id]
if hasattr(self, "_cached_dirs"):
return self.cached_dirs.get(remote_path, [])
return self.cached_dirs.get(path, [])
return []

def get_remote_id(self, path_info, create=False):
if not path_info.path and path_info.bucket:
# Case sensitive base path
return self.bucket
def _path_to_remote_ids(self, path, create):
remote_ids = self._get_known_remote_ids(path)
if remote_ids:
return remote_ids

remote_ids = self.get_remote_id_from_cache(path_info.path)
parent_path, path_part = posixpath.split(path)
parent_ids = self._path_to_remote_ids(parent_path, create)
item = self.get_remote_item(path_part, parent_ids)

if remote_ids:
return remote_ids[0]
if not item:
if create:
item = self.create_remote_dir(parent_ids[0], path_part)
else:
return None

file1 = self.resolve_remote_item_from_path(
path_info.path.split("/"), create
)
return file1["id"] if file1 else ""
return [item["id"]]
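
_path_to_remote_ids above resolves a path by recursing to its parent first, so each segment is looked up (or created) under the IDs already resolved for the segment before it. A simplified, self-contained analogue with an in-memory lookup; `path_to_ids`, `lookup`, and `mkdir` are made-up names, not DVC code:

```python
import posixpath

def path_to_ids(path, lookup, mkdir=None, root_id="root"):
    """lookup(name, parent_ids) -> id or None; mkdir(parent_id, name) -> id."""
    if not path:
        return [root_id]
    parent_path, name = posixpath.split(path)
    parent_ids = path_to_ids(parent_path, lookup, mkdir, root_id)
    if not parent_ids:
        return []
    item_id = lookup(name, parent_ids)
    if item_id is None and mkdir is not None:
        item_id = mkdir(parent_ids[0], name)
    return [item_id] if item_id else []

# toy usage with a fake "Drive" keyed by (parent_id, title)
tree = {("root", "data"): "id-data", ("id-data", "images"): "id-images"}

def lookup(name, parents):
    for parent in parents:
        if (parent, name) in tree:
            return tree[(parent, name)]
    return None

assert path_to_ids("data/images", lookup) == ["id-images"]
assert path_to_ids("data/missing", lookup) == []
```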

def _get_remote_id(self, path_info, create=False):
assert path_info.bucket == self._bucket

remote_ids = self._path_to_remote_ids(path_info.path, create)
if not remote_ids:
raise GDrivePathNotFound(path_info)

return remote_ids[0]

def exists(self, path_info):
return self.get_remote_id(path_info) != ""
try:
self._get_remote_id(path_info)
except GDrivePathNotFound:
return False
else:
return True

def _upload(self, from_file, to_info, name, no_progress_bar):
dirname = to_info.parent
if dirname:
parent_id = self.get_remote_id(dirname, True)
parent_id = self._get_remote_id(dirname, True)
else:
parent_id = to_info.bucket

@@ -372,7 +395,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar):
)

def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self.get_remote_id(from_info)
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def all(self):
@@ -396,5 +419,5 @@ def all(self):
logger.debug('Ignoring path as "non-cache looking"')

def remove(self, path_info):
remote_id = self.get_remote_id(path_info)
remote_id = self._get_remote_id(path_info)
self.delete_remote_file(remote_id)