@contextmanager
@_gdrive_retry
def open(self, path_info, mode="r", encoding=None):
    """Open a file stored on the Google Drive remote for streamed reading.

    Intended for use as a context manager; the yielded stream is closed
    when the ``with`` block exits, whether normally or via an exception.

    Args:
        path_info: location of the remote file; resolved to a Drive item
            id through ``self._get_item_id``.
        mode: one of ``"r"``, ``"rt"`` (text) or ``"rb"`` (binary).
            Writing is not supported.
        encoding: text encoding handed to ``io.TextIOWrapper`` in text
            modes; ignored for ``"rb"``.

    Yields:
        An ``IterStream`` of bytes for ``"rb"``, otherwise an
        ``io.TextIOWrapper`` around it.

    Raises:
        AssertionError: if ``mode`` is not one of the supported read modes.
    """
    # NOTE(review): this is a generator function, so `_gdrive_retry`
    # presumably only wraps creation of the generator object and not the
    # API calls made below -- confirm against its definition.
    assert mode in {"r", "rt", "rb"}

    item_id = self._get_item_id(path_info)
    # CreateFile only builds a local handle for an existing id;
    # it does not create a file on the remote.
    gdrive_file = self._drive.CreateFile({"id": item_id})
    fd = gdrive_file.GetContentIOBuffer()
    stream = IterStream(iter(fd))

    if mode != "rb":
        stream = io.TextIOWrapper(stream, encoding=encoding)

    try:
        yield stream
    finally:
        # Closing the wrapper also closes the wrapped IterStream, so the
        # stream cannot be read by accident after the block has exited.
        stream.close()
import io


class IterStream(io.RawIOBase):
    """Wrap an iterator yielding bytes as a read-only raw file object.

    Empty chunks produced by the iterator are skipped rather than being
    reported to the caller, because a zero-length read is how raw streams
    signal EOF; passing one through would truncate the stream.

    Attributes are kept public for backward compatibility:
    ``iterator`` is the wrapped iterator and ``leftover`` holds the unread
    tail of the most recently fetched chunk (or ``None``).
    """

    def __init__(self, iterator):
        super().__init__()
        self.iterator = iterator
        self.leftover = None

    def readable(self):
        """Report that the stream supports reading."""
        return True

    def _next_chunk(self):
        """Return pending data: leftover bytes first, else the next
        non-empty chunk from the iterator, else ``b""`` at exhaustion."""
        chunk = self.leftover
        self.leftover = None
        while not chunk:
            try:
                chunk = next(self.iterator)
            except StopIteration:
                return b""
        return chunk

    # Python 3 only requires .readinto(), but it uses other methods when
    # they are present and falls back when they are absent. Since the
    # iterator already constructs byte strings for us, .readinto() is not
    # the most efficient path, so .read1() is provided as well.

    def readinto(self, b):
        """Copy at most ``len(b)`` bytes into ``b``.

        Returns the number of bytes written; 0 indicates EOF (or that
        ``b`` has zero length).
        """
        n = len(b)  # We're supposed to return at most this much
        chunk = self._next_chunk()
        output, self.leftover = chunk[:n], chunk[n:]

        n_out = len(output)
        b[:n_out] = output
        return n_out

    readinto1 = readinto

    def read1(self, n=-1):
        """Return up to ``n`` bytes using at most one chunk of the iterator.

        With ``n`` of ``None``, zero or negative, the whole pending chunk
        is returned. ``b""`` indicates EOF.
        """
        chunk = self._next_chunk()

        # Return an arbitrary number of bytes
        if n is None or n <= 0:
            return chunk

        output, self.leftover = chunk[:n], chunk[n:]
        return output