diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 4bffb60b46..57876ab95f 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -1,8 +1,9 @@ import os import posixpath import logging -import threading import re +import threading +from urllib.parse import urlparse from funcy import retry, compose, decorator, wrap_with from funcy.py3 import cat @@ -23,12 +24,22 @@ class GDriveRetriableError(DvcException): pass +class GDrivePathNotFound(DvcException): + def __init__(self, path_info): + super().__init__("Google Drive path '{}' not found.".format(path_info)) + + class GDriveAccessTokenRefreshError(DvcException): - pass + def __init__(self): + super().__init__("Google Drive access token refreshment is failed.") class GDriveMissedCredentialKeyError(DvcException): - pass + def __init__(self, path): + super().__init__( + "Google Drive user credentials file '{}' " + "misses value for key.".format(path) + ) @decorator @@ -49,17 +60,32 @@ def _wrap_pydrive_retriable(call): gdrive_retry = compose( - # 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s + # 15 tries, start at 0.5s, multiply by golden ratio, cap at 20s retry( - 8, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 10) + 15, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 20) ), _wrap_pydrive_retriable, ) +class GDriveURLInfo(CloudURLInfo): + def __init__(self, url): + super().__init__(url) + + # GDrive URL host part is case sensitive, + # we are restoring it here. + p = urlparse(url) + self.host = p.netloc + assert self.netloc == self.host + + # Normalize path. Important since we have a cache (path to ID) + # and don't want to deal with different variations of path in it. + self._spath = re.sub("/{2,}", "/", self._spath.rstrip("/")) + + class RemoteGDrive(RemoteBASE): scheme = Schemes.GDRIVE - path_cls = CloudURLInfo + path_cls = GDriveURLInfo REQUIRES = {"pydrive2": "pydrive2"} DEFAULT_NO_TRAVERSE = False DEFAULT_VERIFY = True @@ -69,32 +95,32 @@ class RemoteGDrive(RemoteBASE): def __init__(self, repo, config): super().__init__(repo, config) - self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL]) - - bucket = re.search( - "{}://(.*)".format(self.scheme), - config[Config.SECTION_REMOTE_URL], - re.IGNORECASE, - ) - self.bucket = ( - bucket.group(1).split("/")[0] if bucket else self.path_info.bucket - ) - + url = config[Config.SECTION_REMOTE_URL] + self.path_info = self.path_cls(url) self.config = config - self.init_drive() - def init_drive(self): - self.client_id = self.config.get(Config.SECTION_GDRIVE_CLIENT_ID, None) - self.client_secret = self.config.get( + if not self.path_info.bucket: + raise DvcException( + "Empty Google Drive URL '{}'. Learn more at " + "{}.".format( + url, format_link("https://man.dvc.org/remote/add") + ) + ) + + self._bucket = self.path_info.bucket + self._client_id = self.config.get( + Config.SECTION_GDRIVE_CLIENT_ID, None + ) + self._client_secret = self.config.get( Config.SECTION_GDRIVE_CLIENT_SECRET, None ) - if not self.client_id or not self.client_secret: + if not self._client_id or not self._client_secret: raise DvcException( "Please specify Google Drive's client id and " - "secret in DVC's config. Learn more at " + "secret in DVC config. Learn more at " "{}.".format(format_link("https://man.dvc.org/remote/add")) ) - self.gdrive_user_credentials_path = ( + self._gdrive_user_credentials_path = ( tmp_fname(os.path.join(self.repo.tmp_dir, "")) if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) else self.config.get( @@ -104,7 +130,8 @@ def init_drive(self): ), ) ) - self.remote_drive_id = None + self._remote_drive_id = None + self._remote_root_id = None @gdrive_retry def gdrive_upload_file( @@ -139,7 +166,7 @@ def gdrive_download_file( # it does not create a file on the remote gdrive_file = self.drive.CreateFile(param) bar_format = ( - "Donwloading {desc:{ncols_desc}.{ncols_desc}}... " + "Downloading {desc:{ncols_desc}.{ncols_desc}}... " + Tqdm.format_sizeof(int(gdrive_file["fileSize"]), "B", 1024) ) with Tqdm( @@ -150,8 +177,8 @@ def gdrive_download_file( def gdrive_list_item(self, query): param = {"q": query, "maxResults": 1000, "corpora": self.corpora} - if self.remote_drive_id: - param["driveId"] = self.remote_drive_id + if self._remote_drive_id: + param["driveId"] = self._remote_drive_id file_list = self.drive.ListFile(param) @@ -165,11 +192,12 @@ def cache_root_dirs(self): cached_dirs = {} cached_ids = {} for dir1 in self.gdrive_list_item( - "'{}' in parents and trashed=false".format(self.remote_root_id) + "'{}' in parents and trashed=false".format(self._remote_root_id) ): remote_path = posixpath.join(self.path_info.path, dir1["title"]) cached_dirs.setdefault(remote_path, []).append(dir1["id"]) cached_ids[dir1["id"]] = dir1["title"] + return cached_dirs, cached_ids @property @@ -201,7 +229,7 @@ def drive(self): if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): with open( - self.gdrive_user_credentials_path, "w" + self._gdrive_user_credentials_path, "w" ) as credentials_file: credentials_file.write( os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA) @@ -209,8 +237,8 @@ def drive(self): GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings" GoogleAuth.DEFAULT_SETTINGS["client_config"] = { - "client_id": self.client_id, - "client_secret": self.client_secret, + "client_id": self._client_id, + "client_secret": self._client_secret, "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://oauth2.googleapis.com/token", "revoke_uri": "https://oauth2.googleapis.com/revoke", @@ -220,11 +248,10 @@ def drive(self): GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file" GoogleAuth.DEFAULT_SETTINGS[ "save_credentials_file" - ] = self.gdrive_user_credentials_path + ] = self._gdrive_user_credentials_path GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [ "https://www.googleapis.com/auth/drive", - # drive.appdata grants access to appDataFolder GDrive directory "https://www.googleapis.com/auth/drive.appdata", ] @@ -234,33 +261,26 @@ def drive(self): try: gauth.CommandLineAuth() except RefreshError as exc: - raise GDriveAccessTokenRefreshError( - "Google Drive's access token refreshment is failed" - ) from exc + raise GDriveAccessTokenRefreshError from exc except KeyError as exc: raise GDriveMissedCredentialKeyError( - "Google Drive's user credentials file '{}' " - "misses value for key '{}'".format( - self.gdrive_user_credentials_path, str(exc) - ) - ) - # Handle pydrive2.auth.AuthenticationError and others auth failures + self._gdrive_user_credentials_path + ) from exc + # Handle pydrive2.auth.AuthenticationError and other auth failures except Exception as exc: raise DvcException( "Google Drive authentication failed" ) from exc finally: if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): - os.remove(self.gdrive_user_credentials_path) + os.remove(self._gdrive_user_credentials_path) self._gdrive = GoogleDrive(gauth) - if self.bucket != "root" and self.bucket != "appDataFolder": - self.remote_drive_id = self.get_remote_drive_id(self.bucket) - self._corpora = "drive" if self.remote_drive_id else "default" - self.remote_root_id = self.get_remote_id( - self.path_info, create=True - ) + if self._bucket != "root" and self._bucket != "appDataFolder": + self._remote_drive_id = self._get_remote_drive_id(self._bucket) + self._corpora = "drive" if self._remote_drive_id else "default" + self._remote_root_id = self._get_remote_id(self.path_info) self._cached_dirs, self._cached_ids = self.cache_root_dirs() @@ -302,65 +322,68 @@ def get_remote_item(self, name, parents_ids): "corpora": self.corpora, } - if self.remote_drive_id: - param["driveId"] = self.remote_drive_id + if self._remote_drive_id: + param["driveId"] = self._remote_drive_id # Limit found remote items count to 1 in response item_list = self.drive.ListFile(param).GetList() return next(iter(item_list), None) @gdrive_retry - def get_remote_drive_id(self, remote_id): + def _get_remote_drive_id(self, remote_id): param = {"id": remote_id} # it does not create a file on the remote item = self.drive.CreateFile(param) item.FetchMetadata("driveId") return item.get("driveId", None) - def resolve_remote_item_from_path(self, path_parts, create): - parents_ids = [self.bucket] - current_path = "" - for path_part in path_parts: - current_path = posixpath.join(current_path, path_part) - remote_ids = self.get_remote_id_from_cache(current_path) - if remote_ids: - parents_ids = remote_ids - continue - item = self.get_remote_item(path_part, parents_ids) - if not item and create: - item = self.create_remote_dir(parents_ids[0], path_part) - elif not item: - return None - parents_ids = [item["id"]] - return item - - def get_remote_id_from_cache(self, remote_path): + def _get_known_remote_ids(self, path): + if not path: + return [self._bucket] + if path == self.path_info.path and self._remote_root_id: + return [self._remote_root_id] if hasattr(self, "_cached_dirs"): - return self.cached_dirs.get(remote_path, []) + return self.cached_dirs.get(path, []) return [] - def get_remote_id(self, path_info, create=False): - if not path_info.path and path_info.bucket: - # Case sensitive base path - return self.bucket + def _path_to_remote_ids(self, path, create): + remote_ids = self._get_known_remote_ids(path) + if remote_ids: + return remote_ids - remote_ids = self.get_remote_id_from_cache(path_info.path) + parent_path, path_part = posixpath.split(path) + parent_ids = self._path_to_remote_ids(parent_path, create) + item = self.get_remote_item(path_part, parent_ids) - if remote_ids: - return remote_ids[0] + if not item: + if create: + item = self.create_remote_dir(parent_ids[0], path_part) + else: + return None - file1 = self.resolve_remote_item_from_path( - path_info.path.split("/"), create - ) - return file1["id"] if file1 else "" + return [item["id"]] + + def _get_remote_id(self, path_info, create=False): + assert path_info.bucket == self._bucket + + remote_ids = self._path_to_remote_ids(path_info.path, create) + if not remote_ids: + raise GDrivePathNotFound(path_info) + + return remote_ids[0] def exists(self, path_info): - return self.get_remote_id(path_info) != "" + try: + self._get_remote_id(path_info) + except GDrivePathNotFound: + return False + else: + return True def _upload(self, from_file, to_info, name, no_progress_bar): dirname = to_info.parent if dirname: - parent_id = self.get_remote_id(dirname, True) + parent_id = self._get_remote_id(dirname, True) else: parent_id = to_info.bucket @@ -372,7 +395,7 @@ def _upload(self, from_file, to_info, name, no_progress_bar): ) def _download(self, from_info, to_file, name, no_progress_bar): - file_id = self.get_remote_id(from_info) + file_id = self._get_remote_id(from_info) self.gdrive_download_file(file_id, to_file, name, no_progress_bar) def all(self): @@ -396,5 +419,5 @@ def all(self): logger.debug('Ignoring path as "non-cache looking"') def remove(self, path_info): - remote_id = self.get_remote_id(path_info) + remote_id = self._get_remote_id(path_info) self.delete_remote_file(remote_id)