From 82e627f7fcbb39928f9f001aef59bd7ba59b6ac2 Mon Sep 17 00:00:00 2001
From: Batuhan Taskaya
Date: Mon, 9 Aug 2021 10:57:32 +0300
Subject: [PATCH] gdrive: migrate to pydrive2.fs

---
 dvc/fs/gdrive.py                 | 383 ++-----------------------------
 setup.py                         |   5 +-
 tests/func/test_data_cloud.py    |   8 +-
 tests/func/test_fs.py            |   6 +-
 tests/remotes/gdrive.py          |   4 +-
 tests/unit/remote/test_gdrive.py |   4 +-
 6 files changed, 32 insertions(+), 378 deletions(-)

diff --git a/dvc/fs/gdrive.py b/dvc/fs/gdrive.py
index ee277f2c1c..e622e1619c 100644
--- a/dvc/fs/gdrive.py
+++ b/dvc/fs/gdrive.py
@@ -1,25 +1,18 @@
-import errno
-import io
 import logging
 import os
 import posixpath
 import re
 import threading
-from collections import defaultdict
-from contextlib import contextmanager
 from urllib.parse import urlparse
 
-from funcy import cached_property, retry, wrap_prop, wrap_with
-from funcy.py3 import cat
+from funcy import cached_property, retry, wrap_prop
 
 from dvc.exceptions import DvcException
 from dvc.path_info import CloudURLInfo
-from dvc.progress import Tqdm
 from dvc.scheme import Schemes
 from dvc.utils import format_link, tmp_fname
-from dvc.utils.stream import IterStream
 
-from .base import BaseFileSystem
+from .fsspec_wrapper import CallbackMixin, FSSpecWrapper
 
 logger = logging.getLogger(__name__)
 FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
@@ -86,7 +79,9 @@ def __init__(self, url):
         self._spath = re.sub("/{2,}", "/", self._spath.rstrip("/"))
 
 
-class GDriveFileSystem(BaseFileSystem):  # pylint:disable=abstract-method
+class GDriveFileSystem(
+    CallbackMixin, FSSpecWrapper
+):  # pylint:disable=abstract-method
     scheme = Schemes.GDRIVE
     PATH_CLS = GDriveURLInfo
     PARAM_CHECKSUM = "checksum"
@@ -214,9 +209,9 @@ def _validate_credentials(auth, settings):
 
     @wrap_prop(threading.RLock())
     @cached_property
-    def _drive(self):
+    def fs(self):
         from pydrive2.auth import GoogleAuth
-        from pydrive2.drive import GoogleDrive
+        from pydrive2.fs import GDriveFileSystem as _GDriveFileSystem
 
         temporary_save_path = self._gdrive_user_credentials_path
         is_credentials_temp = os.getenv(
@@ -291,359 +286,19 @@ def _drive(self):
         if is_credentials_temp:
             os.remove(temporary_save_path)
 
-        return GoogleDrive(gauth)
-
-    @wrap_prop(threading.RLock())
-    @cached_property
-    def _ids_cache(self):
-        cache = {
-            "dirs": defaultdict(list),
-            "ids": {},
-            "root_id": self._get_item_id(
-                self.path_info,
-                use_cache=False,
-                hint="Confirm the directory exists and you can access it.",
-            ),
-        }
-
-        self._cache_path_id(self._path, cache["root_id"], cache=cache)
-
-        for item in self._gdrive_list(
-            "'{}' in parents and trashed=false".format(cache["root_id"])
-        ):
-            item_path = (self.path_info / item["title"]).path
-            self._cache_path_id(item_path, item["id"], cache=cache)
-
-        return cache
-
-    def _cache_path_id(self, path, *item_ids, cache=None):
-        cache = cache or self._ids_cache
-        for item_id in item_ids:
-            cache["dirs"][path].append(item_id)
-            cache["ids"][item_id] = path
-
-    @cached_property
-    def _list_params(self):
-        params = {"corpora": "default"}
-        if self._bucket != "root" and self._bucket != "appDataFolder":
-            drive_id = self._gdrive_shared_drive_id(self._bucket)
-            if drive_id:
-                logger.debug(
-                    "GDrive remote '{}' is using shared drive id '{}'.".format(
-                        self.path_info, drive_id
-                    )
-                )
-                params["driveId"] = drive_id
-                params["corpora"] = "drive"
-        return params
-
-    @_gdrive_retry
-    def _gdrive_shared_drive_id(self, item_id):
-        from pydrive2.files import ApiRequestError
-
-        param = {"id": item_id}
-        # it does not create a file on the remote
-        item = self._drive.CreateFile(param)
-        # ID of the shared drive the item resides in.
-        # Only populated for items in shared drives.
-        try:
-            item.FetchMetadata("driveId")
-        except ApiRequestError as exc:
-            error_code = exc.error.get("code", 0)
-            if error_code == 404:
-                raise DvcException(
-                    "'{}' for '{}':\n\n"
-                    "1. Confirm the directory exists and you can access it.\n"
-                    "2. Make sure that credentials in '{}'\n"
-                    "   are correct for this remote e.g. "
-                    "use the `gdrive_user_credentials_file` config\n"
-                    "   option if you use multiple GDrive remotes with "
-                    "different email accounts.\n\nDetails".format(
-                        item_id, self.path_info, self.credentials_location
-                    )
-                ) from exc
-            raise
-
-        return item.get("driveId", None)
-
-    @_gdrive_retry
-    def _gdrive_upload_fobj(self, fobj, parent_id, title):
-        item = self._drive.CreateFile(
-            {"title": title, "parents": [{"id": parent_id}]}
-        )
-        item.content = fobj
-        item.Upload()
-        return item
-
-    @_gdrive_retry
-    def _gdrive_download_file(
-        self, item_id, to_file, progress_desc, no_progress_bar
-    ):
-        param = {"id": item_id}
-        # it does not create a file on the remote
-        gdrive_file = self._drive.CreateFile(param)
-
-        with Tqdm(
-            desc=progress_desc,
-            disable=no_progress_bar,
-            bytes=True,
-            # explicit `bar_format` as `total` will be set by `update_to`
-            bar_format=Tqdm.BAR_FMT_DEFAULT,
-        ) as pbar:
-            gdrive_file.GetContentFile(to_file, callback=pbar.update_to)
-
-    @contextmanager
-    @_gdrive_retry
-    def open(self, path_info, mode="r", encoding=None, **kwargs):
-        assert mode in {"r", "rt", "rb"}
-
-        item_id = self._get_item_id(path_info)
-        param = {"id": item_id}
-        # it does not create a file on the remote
-        gdrive_file = self._drive.CreateFile(param)
-        fd = gdrive_file.GetContentIOBuffer()
-        stream = IterStream(iter(fd))
-
-        if mode != "rb":
-            stream = io.TextIOWrapper(stream, encoding=encoding)
-
-        yield stream
-
-    @_gdrive_retry
-    def gdrive_delete_file(self, item_id):
-        from pydrive2.files import ApiRequestError
-
-        param = {"id": item_id}
-        # it does not create a file on the remote
-        item = self._drive.CreateFile(param)
-
-        try:
-            item.Trash() if self._trash_only else item.Delete()
-        except ApiRequestError as exc:
-            http_error_code = exc.error.get("code", 0)
-            if (
-                http_error_code == 403
-                and self._list_params["corpora"] == "drive"
-                and exc.GetField("location") == "file.permissions"
-            ):
-                raise DvcException(
-                    "Insufficient permissions to {}. You should have {} "
-                    "access level for the used shared drive. More details "
-                    "at {}.".format(
-                        "move the file into Trash"
-                        if self._trash_only
-                        else "permanently delete the file",
-                        "Manager or Content Manager"
-                        if self._trash_only
-                        else "Manager",
-                        "https://support.google.com/a/answer/7337554",
-                    )
-                ) from exc
-            raise
-
-    def _gdrive_list(self, query):
-        param = {"q": query, "maxResults": 1000}
-        param.update(self._list_params)
-        file_list = self._drive.ListFile(param)
-
-        # Isolate and decorate fetching of remote drive items in pages.
-        get_list = _gdrive_retry(lambda: next(file_list, None))
-
-        # Fetch pages until None is received, lazily flatten the thing.
-        return cat(iter(get_list, None))
-
-    @_gdrive_retry
-    def _gdrive_create_dir(self, parent_id, title):
-        parent = {"id": parent_id}
-        item = self._drive.CreateFile(
-            {"title": title, "parents": [parent], "mimeType": FOLDER_MIME_TYPE}
+        return _GDriveFileSystem(
+            self._with_bucket(self.path_info),
+            gauth,
+            trash_only=self._trash_only,
         )
-        item.Upload()
-        return item
-
-    @wrap_with(threading.RLock())
-    def _create_dir(self, parent_id, title, remote_path):
-        cached = self._ids_cache["dirs"].get(remote_path)
-        if cached:
-            return cached[0]
-
-        item = self._gdrive_create_dir(parent_id, title)
-
-        if parent_id == self._ids_cache["root_id"]:
-            self._cache_path_id(remote_path, item["id"])
-
-        return item["id"]
-
-    def _get_remote_item_ids(self, parent_ids, title):
-        if not parent_ids:
-            return None
-        query = "trashed=false and ({})".format(
-            " or ".join(
-                f"'{parent_id}' in parents" for parent_id in parent_ids
-            )
-        )
-        query += " and title='{}'".format(title.replace("'", "\\'"))
-
-        # GDrive list API is case insensitive, we need to compare
-        # all results and pick the ones with the right title
-        return [
-            item["id"]
-            for item in self._gdrive_list(query)
-            if item["title"] == title
-        ]
-
-    def _get_cached_item_ids(self, path, use_cache):
-        if not path:
-            return [self._bucket]
-        if use_cache:
-            return self._ids_cache["dirs"].get(path, [])
-        return []
-
-    def _path_to_item_ids(self, path, create=False, use_cache=True):
-        item_ids = self._get_cached_item_ids(path, use_cache)
-        if item_ids:
-            return item_ids
-
-        parent_path, title = posixpath.split(path)
-        parent_ids = self._path_to_item_ids(parent_path, create, use_cache)
-        item_ids = self._get_remote_item_ids(parent_ids, title)
-        if item_ids:
-            return item_ids
-
-        return (
-            [self._create_dir(min(parent_ids), title, path)] if create else []
-        )
+    def _with_bucket(self, path):
+        if isinstance(path, self.PATH_CLS):
+            return posixpath.join(path.bucket, path.path)
 
-    def _get_item_id(self, path_info, create=False, use_cache=True, hint=None):
-        assert path_info.bucket == self._bucket
-
-        item_ids = self._path_to_item_ids(path_info.path, create, use_cache)
-        if item_ids:
-            return min(item_ids)
-
-        assert not create
-        raise FileNotFoundError(
-            errno.ENOENT, os.strerror(errno.ENOENT), hint or path_info
-        )
-
-    def exists(self, path_info) -> bool:
-        try:
-            self._get_item_id(path_info)
-        except FileNotFoundError:
-            return False
-        else:
-            return True
-
-    def _gdrive_list_ids(self, query_ids):
-        query = " or ".join(
-            f"'{query_id}' in parents" for query_id in query_ids
-        )
-        query = f"({query}) and trashed=false"
-        return self._gdrive_list(query)
-
-    def find(self, path_info, detail=False, prefix=None):
-        root_path = path_info.path
-        seen_paths = set()
-
-        dir_ids = [self._ids_cache["ids"].copy()]
-        while dir_ids:
-            query_ids = {
-                dir_id: dir_name
-                for dir_id, dir_name in dir_ids.pop().items()
-                if posixpath.commonpath([root_path, dir_name]) == root_path
-                if dir_id not in seen_paths
-            }
-            if not query_ids:
-                continue
-
-            seen_paths |= query_ids.keys()
-
-            new_query_ids = {}
-            dir_ids.append(new_query_ids)
-            for item in self._gdrive_list_ids(query_ids):
-                parent_id = item["parents"][0]["id"]
-                item_path = posixpath.join(query_ids[parent_id], item["title"])
-                if item["mimeType"] == FOLDER_MIME_TYPE:
-                    new_query_ids[item["id"]] = item_path
-                    self._cache_path_id(item_path, item["id"])
-                    continue
-
-                if detail:
-                    yield {
-                        "type": "file",
-                        "name": item_path,
-                        "size": item["fileSize"],
-                        "checksum": item["md5Checksum"],
-                    }
-                else:
-                    yield item_path
-
-    def ls(self, path_info, detail=False):
-        cached = path_info.path in self._ids_cache["dirs"]
-        if cached:
-            dir_ids = self._ids_cache["dirs"][path_info.path]
-        else:
-            dir_ids = self._path_to_item_ids(path_info.path)
-
-        if not dir_ids:
-            return None
-
-        root_path = path_info.path
-        for item in self._gdrive_list_ids(dir_ids):
-            item_path = posixpath.join(root_path, item["title"])
-            if detail:
-                if item["mimeType"] == FOLDER_MIME_TYPE:
-                    yield {"type": "directory", "name": item_path}
-                else:
-                    yield {
-                        "type": "file",
-                        "name": item_path,
-                        "size": item["fileSize"],
-                        "checksum": item["md5Checksum"],
-                    }
-            else:
-                yield item_path
-
-        if not cached:
-            self._cache_path_id(root_path, *dir_ids)
-
-    def walk_files(self, path_info, **kwargs):
-        for file in self.find(path_info):
-            yield path_info.replace(path=file)
-
-    def remove(self, path_info):
-        item_id = self._get_item_id(path_info)
-        self.gdrive_delete_file(item_id)
-
-    def info(self, path_info):
-        item_id = self._get_item_id(path_info)
-        gdrive_file = self._drive.CreateFile({"id": item_id})
-        gdrive_file.FetchMetadata(fields="fileSize")
-        return {"size": int(gdrive_file.get("fileSize")), "type": "file"}
-
-    def _upload_fobj(self, fobj, to_info, **kwargs):
-        dirname = to_info.parent
-        assert dirname
-        parent_id = self._get_item_id(dirname, create=True)
-        self._gdrive_upload_fobj(fobj, parent_id, to_info.name)
-
-    def _upload(
-        self,
-        from_file,
-        to_info,
-        name=None,
-        no_progress_bar=False,
-        **_kwargs,
-    ):
-        with open(from_file, "rb") as fobj:
-            self.upload_fobj(
-                fobj,
-                to_info,
-                size=os.path.getsize(from_file),
-                no_progress_bar=no_progress_bar,
-                desc=name or to_info.name,
-            )
+        return super()._with_bucket(path)
 
-    def _download(self, from_info, to_file, name=None, no_progress_bar=False):
-        item_id = self._get_item_id(from_info)
-        self._gdrive_download_file(item_id, to_file, name, no_progress_bar)
+    def _upload_fobj(self, fobj, to_info, size: int = None):
+        rpath = self._with_bucket(to_info)
+        self.makedirs(os.path.dirname(rpath))
+        return self.fs.upload_fobj(fobj, rpath, size=size)
diff --git a/setup.py b/setup.py
index 67f17a8f3b..d91e34946e 100644
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,10 @@ def run(self):
 # Extra dependencies for remote integrations
 
 gs = ["gcsfs==2021.7.0"]
-gdrive = ["pydrive2>=1.8.1", "six >= 1.13.0"]
+gdrive = [
+    "pydrive2[fsspec] @ git+https://github.com/iterative/pydrive2@fsspec",
+    "six >= 1.13.0",
+]
 s3 = ["s3fs==2021.7.0", "aiobotocore[boto3]>1.0.1"]
 azure = ["adlfs==2021.7.1", "azure-identity>=1.4.0", "knack"]
 # https://github.com/Legrandin/pycryptodome/issues/465
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index 9878033b7d..2fb4cd42c2 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -24,14 +24,8 @@
         "webdav",
         "webhdfs",
         "oss",
+        "gdrive",
     ]
-] + [
-    pytest.param(
-        pytest.lazy_fixture("gdrive"),
-        marks=pytest.mark.xfail(
-            reason="https://github.com/iterative/dvc/issues/6347"
-        ),
-    )
 ]
 
 # Clouds that implement the general methods that can be tested
diff --git a/tests/func/test_fs.py b/tests/func/test_fs.py
index 96154f9d76..8452f90e1e 100644
--- a/tests/func/test_fs.py
+++ b/tests/func/test_fs.py
@@ -301,7 +301,9 @@ def test_fs_ls(dvc, cloud):
     fs = cls(**config)
     path_info /= "directory"
 
-    assert {os.path.basename(file_key) for file_key in fs.ls(path_info)} == {
+    assert {
+        os.path.basename(file_key.rstrip("/")) for file_key in fs.ls(path_info)
+    } == {
         "foo",
         "bar",
         "baz",
@@ -309,7 +311,7 @@ def test_fs_ls(dvc, cloud):
    }
    assert set(fs.ls(path_info / "empty")) == set()
    assert {
-        (detail["type"], os.path.basename(detail["name"]))
+        (detail["type"], os.path.basename(detail["name"].rstrip("/")))
         for detail in fs.ls(path_info / "baz", detail=True)
     } == {("file", "quux"), ("directory", "egg")}
diff --git a/tests/remotes/gdrive.py b/tests/remotes/gdrive.py
index f4c452ba73..b67ceae2df 100644
--- a/tests/remotes/gdrive.py
+++ b/tests/remotes/gdrive.py
@@ -94,7 +94,7 @@ def client(self):
 
     @_gdrive_retry
     def mkdir(self, mode=0o777, parents=False, exist_ok=False):
-        if not self.client.exists(self.path):
+        if not self.client.info(self.path):
             self.client.mkdir(self.path)
 
     @_gdrive_retry
@@ -124,5 +124,5 @@ def gdrive(test_config, make_tmp_dir):
     fs = GDriveFileSystem(
         gdrive_credentials_tmp_dir=tmp_dir.dvc.tmp_dir, **ret.config
     )
-    fs._gdrive_create_dir("root", fs.path_info.path)
+    fs.fs._gdrive_create_dir("root", fs.path_info.path)
     yield ret
diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py
index 4fe02eb8da..b0dbf98781 100644
--- a/tests/unit/remote/test_gdrive.py
+++ b/tests/unit/remote/test_gdrive.py
@@ -29,7 +29,7 @@ def test_drive(self, dvc, monkeypatch):
             USER_CREDS_TOKEN_REFRESH_ERROR,
         )
         with pytest.raises(GDriveAuthError):
-            assert fs._drive
+            assert fs.fs
 
         monkeypatch.setenv(GDriveFileSystem.GDRIVE_CREDENTIALS_DATA, "")
         fs = GDriveFileSystem(
@@ -40,4 +40,4 @@ def test_drive(self, dvc, monkeypatch):
             USER_CREDS_MISSED_KEY_ERROR,
         )
         with pytest.raises(GDriveAuthError):
-            assert fs._drive
+            assert fs.fs
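-- 
With this migration, `GDriveFileSystem.fs` is the fsspec-compatible filesystem
from `pydrive2.fs`, so listing, transfers, and deletion flow through the
generic `FSSpecWrapper`/`CallbackMixin` machinery instead of hand-rolled
PyDrive2 calls. Below is a minimal sketch of using the wrapped object on its
own; the remote path and the interactive auth flow are illustrative
placeholders, not values taken from this patch:

    from pydrive2.auth import GoogleAuth
    from pydrive2.fs import GDriveFileSystem

    # Obtain credentials however the surrounding code normally would; this
    # sketch assumes an interactive local-webserver OAuth flow.
    gauth = GoogleAuth()
    gauth.LocalWebserverAuth()

    # The first argument is the fsspec root, i.e. the "<bucket>/<path>" string
    # that _with_bucket() builds from a gdrive:// URL, for example
    # "gdrive://root/my-remote" -> "root/my-remote".
    fs = GDriveFileSystem("root/my-remote", gauth, trash_only=True)

    # Standard fsspec-style calls that DVC now delegates to:
    for entry in fs.ls("root/my-remote", detail=True):
        print(entry["name"], entry["type"])

This is also why the new `_upload_fobj` only has to translate `to_info` with
`_with_bucket()`, ensure the parent directory with `makedirs()`, and hand the
stream to `self.fs.upload_fobj(...)`.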