diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index de4211641e..8894fdf9fe 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -38,7 +38,7 @@ def __init__(self, path): ) -def gdrive_retry(func): +def _gdrive_retry(func): def should_retry(exc): from pydrive2.files import ApiRequestError @@ -48,12 +48,12 @@ def should_retry(exc): retry_codes = [403, 500, 502, 503, 504] result = exc.error.get("code", 0) in retry_codes if result: - logger.debug("Retry GDrive API call failed with {}.".format(exc)) + logger.debug("Retrying GDrive API call, error: {}.".format(exc)) return result - # 15 tries, start at 0.5s, multiply by golden ratio, cap at 20s + # 16 tries, start at 0.5s, multiply by golden ratio, cap at 20s return retry( - 15, + 16, timeout=lambda a: min(0.5 * 1.618 ** a, 20), filter_errors=should_retry, )(func) @@ -171,7 +171,7 @@ def _validate_config(self): @wrap_prop(threading.RLock()) @cached_property - def drive(self): + def _drive(self): from pydrive2.auth import RefreshError from pydrive2.auth import GoogleAuth from pydrive2.drive import GoogleDrive @@ -237,38 +237,50 @@ def drive(self): @wrap_prop(threading.RLock()) @cached_property - def cache(self): - cache = {"dirs": defaultdict(list), "ids": {}} + def _ids_cache(self): + cache = { + "dirs": defaultdict(list), + "ids": {}, + "root_id": self._get_item_id(self.path_info, use_cache=False), + } - cache["root_id"] = self._get_remote_id(self.path_info) - cache["dirs"][self.path_info.path] = [cache["root_id"]] - self._cache_path(self.path_info.path, cache["root_id"], cache) + self._cache_path_id(self.path_info.path, cache["root_id"], cache) - for item in self.gdrive_list_item( + for item in self._gdrive_list( "'{}' in parents and trashed=false".format(cache["root_id"]) ): - remote_path = (self.path_info / item["title"]).path - self._cache_path(remote_path, item["id"], cache) + item_path = (self.path_info / item["title"]).path + self._cache_path_id(item_path, item["id"], cache) return cache - def _cache_path(self, remote_path, remote_id, cache=None): - cache = cache or self.cache - cache["dirs"][remote_path].append(remote_id) - cache["ids"][remote_id] = remote_path + def _cache_path_id(self, path, item_id, cache=None): + cache = cache or self._ids_cache + cache["dirs"][path].append(item_id) + cache["ids"][item_id] = path @cached_property - def list_params(self): + def _list_params(self): params = {"corpora": "default"} if self._bucket != "root" and self._bucket != "appDataFolder": - drive_id = self._get_remote_drive_id(self._bucket) + drive_id = self._gdrive_shared_drive_id(self._bucket) if drive_id: params["driveId"] = drive_id params["corpora"] = "drive" return params - @gdrive_retry - def gdrive_upload_file( + @_gdrive_retry + def _gdrive_shared_drive_id(self, item_id): + param = {"id": item_id} + # it does not create a file on the remote + item = self._drive.CreateFile(param) + # ID of the shared drive the item resides in. + # Only populated for items in shared drives. + item.FetchMetadata("driveId") + return item.get("driveId", None) + + @_gdrive_retry + def _gdrive_upload_file( self, parent_id, title, @@ -276,7 +288,7 @@ def gdrive_upload_file( from_file="", progress_name="", ): - item = self.drive.CreateFile( + item = self._drive.CreateFile( {"title": title, "parents": [{"id": parent_id}]} ) @@ -296,63 +308,29 @@ def gdrive_upload_file( item.Upload() return item - @gdrive_retry - def gdrive_download_file( - self, file_id, to_file, progress_name, no_progress_bar + @_gdrive_retry + def _gdrive_download_file( + self, item_id, to_file, progress_desc, no_progress_bar ): - param = {"id": file_id} + param = {"id": item_id} # it does not create a file on the remote - gdrive_file = self.drive.CreateFile(param) + gdrive_file = self._drive.CreateFile(param) bar_format = ( "Downloading {desc:{ncols_desc}.{ncols_desc}}... " + Tqdm.format_sizeof(int(gdrive_file["fileSize"]), "B", 1024) ) with Tqdm( - bar_format=bar_format, desc=progress_name, disable=no_progress_bar + bar_format=bar_format, desc=progress_desc, disable=no_progress_bar ): gdrive_file.GetContentFile(to_file) - def gdrive_list_item(self, query): - param = {"q": query, "maxResults": 1000} - param.update(self.list_params) - - file_list = self.drive.ListFile(param) - - # Isolate and decorate fetching of remote drive items in pages - get_list = gdrive_retry(lambda: next(file_list, None)) - - # Fetch pages until None is received, lazily flatten the thing - return cat(iter(get_list, None)) - - @wrap_with(threading.RLock()) - def gdrive_create_dir(self, parent_id, title, remote_path): - cached = self.cache["dirs"].get(remote_path) - if cached: - return cached[0] - - item = self._create_remote_dir(parent_id, title) - - if parent_id == self.cache["root_id"]: - self._cache_path(remote_path, item["id"]) - - return item["id"] - - @gdrive_retry - def _create_remote_dir(self, parent_id, title): - parent = {"id": parent_id} - item = self.drive.CreateFile( - {"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE} - ) - item.Upload() - return item - - @gdrive_retry - def _delete_remote_file(self, remote_id): + @_gdrive_retry + def _gdrive_delete_file(self, item_id): from pydrive2.files import ApiRequestError - param = {"id": remote_id} + param = {"id": item_id} # it does not create a file on the remote - item = self.drive.CreateFile(param) + item = self._drive.CreateFile(param) try: item.Trash() if self._trash_only else item.Delete() @@ -360,7 +338,7 @@ def _delete_remote_file(self, remote_id): http_error_code = exc.error.get("code", 0) if ( http_error_code == 403 - and self.list_params["corpora"] == "drive" + and self._list_params["corpora"] == "drive" and _location(exc) == "file.permissions" ): raise DvcException( @@ -378,112 +356,132 @@ def _delete_remote_file(self, remote_id): ) from exc raise - @gdrive_retry - def _get_remote_item(self, name, parents_ids): - if not parents_ids: - return None - query = "({})".format( - " or ".join( - "'{}' in parents".format(parent_id) - for parent_id in parents_ids - ) + def _gdrive_list(self, query): + param = {"q": query, "maxResults": 1000} + param.update(self._list_params) + file_list = self._drive.ListFile(param) + + # Isolate and decorate fetching of remote drive items in pages. + get_list = _gdrive_retry(lambda: next(file_list, None)) + + # Fetch pages until None is received, lazily flatten the thing. + return cat(iter(get_list, None)) + + @_gdrive_retry + def _gdrive_create_dir(self, parent_id, title): + parent = {"id": parent_id} + item = self._drive.CreateFile( + {"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE} ) + item.Upload() + return item - query += " and trashed=false and title='{}'".format(name) + @wrap_with(threading.RLock()) + def _create_dir(self, parent_id, title, remote_path): + cached = self._ids_cache["dirs"].get(remote_path) + if cached: + return cached[0] - # Remote might contain items with duplicated path (titles). - # We thus limit number of items. - param = {"q": query, "maxResults": 1} - param.update(self.list_params) + item = self._gdrive_create_dir(parent_id, title) - # Limit found remote items count to 1 in response - item_list = self.drive.ListFile(param).GetList() - return next(iter(item_list), None) + if parent_id == self._ids_cache["root_id"]: + self._cache_path_id(remote_path, item["id"]) - @gdrive_retry - def _get_remote_drive_id(self, remote_id): - param = {"id": remote_id} - # it does not create a file on the remote - item = self.drive.CreateFile(param) - item.FetchMetadata("driveId") - return item.get("driveId", None) + return item["id"] - def _get_cached_remote_ids(self, path): + def _get_remote_item_ids(self, parent_ids, title): + if not parent_ids: + return None + query = "trashed=false and ({})".format( + " or ".join( + "'{}' in parents".format(parent_id) for parent_id in parent_ids + ) + ) + query += " and title='{}'".format(title.replace("'", "\\'")) + + # GDrive list API is case insensitive, we need to compare + # all results and pick the ones with the right title + return [ + item["id"] + for item in self._gdrive_list(query) + if item["title"] == title + ] + + def _get_cached_item_ids(self, path, use_cache): if not path: return [self._bucket] - if "cache" in self.__dict__: - return self.cache["dirs"].get(path, []) + if use_cache: + return self._ids_cache["dirs"].get(path, []) return [] - def _path_to_remote_ids(self, path, create): - remote_ids = self._get_cached_remote_ids(path) - if remote_ids: - return remote_ids - - parent_path, part = posixpath.split(path) - parent_ids = self._path_to_remote_ids(parent_path, create) - item = self._get_remote_item(part, parent_ids) + def _path_to_item_ids(self, path, create, use_cache): + item_ids = self._get_cached_item_ids(path, use_cache) + if item_ids: + return item_ids - if not item: - return ( - [self.gdrive_create_dir(parent_ids[0], part, path)] - if create - else [] - ) + parent_path, title = posixpath.split(path) + parent_ids = self._path_to_item_ids(parent_path, create, use_cache) + item_ids = self._get_remote_item_ids(parent_ids, title) + if item_ids: + return item_ids - return [item["id"]] + return ( + [self._create_dir(min(parent_ids), title, path)] if create else [] + ) - def _get_remote_id(self, path_info, create=False): + def _get_item_id(self, path_info, create=False, use_cache=True): assert path_info.bucket == self._bucket - remote_ids = self._path_to_remote_ids(path_info.path, create) - if not remote_ids: - raise GDrivePathNotFound(path_info) + item_ids = self._path_to_item_ids(path_info.path, create, use_cache) + if item_ids: + return min(item_ids) - return remote_ids[0] + raise GDrivePathNotFound(path_info) def exists(self, path_info): try: - self._get_remote_id(path_info) + self._get_item_id(path_info) except GDrivePathNotFound: return False else: return True - def _upload(self, from_file, to_info, name, no_progress_bar): + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): dirname = to_info.parent assert dirname - parent_id = self._get_remote_id(dirname, True) + parent_id = self._get_item_id(dirname, True) - self.gdrive_upload_file( + self._gdrive_upload_file( parent_id, to_info.name, no_progress_bar, from_file, name ) - def _download(self, from_info, to_file, name, no_progress_bar): - file_id = self._get_remote_id(from_info) - self.gdrive_download_file(file_id, to_file, name, no_progress_bar) + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + item_id = self._get_item_id(from_info) + self._gdrive_download_file(item_id, to_file, name, no_progress_bar) def list_cache_paths(self, prefix=None, progress_callback=None): - if not self.cache["ids"]: + if not self._ids_cache["ids"]: return if prefix: - dir_ids = self.cache["dirs"].get(prefix[:2]) + dir_ids = self._ids_cache["dirs"].get(prefix[:2]) if not dir_ids: return else: - dir_ids = self.cache["ids"] + dir_ids = self._ids_cache["ids"] parents_query = " or ".join( "'{}' in parents".format(dir_id) for dir_id in dir_ids ) query = "({}) and trashed=false".format(parents_query) - for item in self.gdrive_list_item(query): + for item in self._gdrive_list(query): if progress_callback: progress_callback() parent_id = item["parents"][0]["id"] - yield posixpath.join(self.cache["ids"][parent_id], item["title"]) + yield posixpath.join( + self._ids_cache["ids"][parent_id], item["title"] + ) def remove(self, path_info): - remote_id = self._get_remote_id(path_info) - self._delete_remote_file(remote_id) + item_id = self._get_item_id(path_info) + self._gdrive_delete_file(item_id) diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 57945afeb0..1bf93ab7fb 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -200,7 +200,7 @@ def setup_gdrive_cloud(remote_url, dvc): dvc.config = config remote = DataCloud(dvc).get_remote() - remote._create_remote_dir("root", remote.path_info.path) + remote._gdrive_create_dir("root", remote.path_info.path) class TestRemoteGDrive(GDrive, TestDataCloudBase): diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index 72b790ea92..2e9c8f7917 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -35,7 +35,7 @@ def test_drive(self): RemoteGDrive.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_TOKEN_REFRESH_ERROR with pytest.raises(GDriveAccessTokenRefreshError): - remote.drive + remote._drive os.environ[RemoteGDrive.GDRIVE_CREDENTIALS_DATA] = "" remote = RemoteGDrive(Repo(), self.CONFIG) @@ -43,4 +43,4 @@ def test_drive(self): RemoteGDrive.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_MISSED_KEY_ERROR with pytest.raises(GDriveMissedCredentialKeyError): - remote.drive + remote._drive