From 4a623cdc2506b823cce7672f33ebca5504ce256d Mon Sep 17 00:00:00 2001 From: "zhen.chen" Date: Thu, 2 Dec 2021 18:04:44 +0800 Subject: [PATCH] feat(client): use "session" instead of "urlopen" in "RemoteData.open" PR Closed: https://github.com/Graviti-AI/tensorbay-python-sdk/pull/1143 --- tensorbay/client/dataset.py | 3 +- tensorbay/client/gas.py | 3 +- tensorbay/client/profile.py | 11 +- tensorbay/client/requests.py | 191 ++------------------- tensorbay/utility/__init__.py | 6 +- tensorbay/utility/file.py | 28 +-- tensorbay/utility/request_config.py | 33 ---- tensorbay/utility/requests.py | 254 ++++++++++++++++++++++++++++ 8 files changed, 296 insertions(+), 233 deletions(-) delete mode 100644 tensorbay/utility/request_config.py create mode 100644 tensorbay/utility/requests.py diff --git a/tensorbay/client/dataset.py b/tensorbay/client/dataset.py index 3451b030e..6bfe96664 100644 --- a/tensorbay/client/dataset.py +++ b/tensorbay/client/dataset.py @@ -37,7 +37,7 @@ UPLOAD_SEGMENT_RESUME_TEMPLATE_CLI, UPLOAD_SEGMENT_RESUME_TEMPLATE_SDK, ) -from tensorbay.client.requests import Tqdm, multithread_upload +from tensorbay.client.requests import multithread_upload from tensorbay.client.segment import _STRATEGIES, FusionSegmentClient, SegmentClient from tensorbay.client.statistics import Statistics from tensorbay.client.status import Status @@ -51,6 +51,7 @@ StatusError, ) from tensorbay.label import Catalog +from tensorbay.utility import Tqdm if TYPE_CHECKING: from tensorbay.client.gas import GAS diff --git a/tensorbay/client/gas.py b/tensorbay/client/gas.py index 3f81d252e..5bda891c9 100644 --- a/tensorbay/client/gas.py +++ b/tensorbay/client/gas.py @@ -22,11 +22,12 @@ from tensorbay.client.dataset import DatasetClient, FusionDatasetClient from tensorbay.client.lazy import PagingList from tensorbay.client.log import UPLOAD_DATASET_RESUME_TEMPLATE -from tensorbay.client.requests import Client, Tqdm +from tensorbay.client.requests import Client from tensorbay.client.status import Status from tensorbay.client.struct import ROOT_COMMIT_ID, UserInfo from tensorbay.dataset import Dataset, FusionDataset from tensorbay.exception import DatasetTypeError, ResourceNotExistError +from tensorbay.utility import Tqdm DatasetClientType = Union[DatasetClient, FusionDatasetClient] diff --git a/tensorbay/client/profile.py b/tensorbay/client/profile.py index b237c54ef..a6ed9c4c5 100644 --- a/tensorbay/client/profile.py +++ b/tensorbay/client/profile.py @@ -15,7 +15,6 @@ import time from collections import OrderedDict, defaultdict from functools import wraps -from http.client import HTTPResponse from itertools import chain from multiprocessing import Manager from multiprocessing.managers import SyncManager @@ -27,10 +26,10 @@ from requests_toolbelt.multipart.encoder import FileWrapper, MultipartEncoder from tensorbay.client.requests import Client -from tensorbay.utility.file import RemoteFileMixin +from tensorbay.utility import RemoteFileMixin, UserResponse _Callable = TypeVar("_Callable", bound=Callable[..., Response]) -_OpenCallable = TypeVar("_OpenCallable", bound=Callable[..., HTTPResponse]) +_OpenCallable = TypeVar("_OpenCallable", bound=Callable[..., UserResponse]) _ReadCallable = Callable[..., bytes] _COLUMNS = OrderedDict( @@ -152,9 +151,9 @@ def wrapper(client: Client, method: str, url: str, **kwargs: Any) -> Response: return wrapper # type: ignore[return-value] def _statistical_read(self, download_path: str) -> _ReadCallable: - def wrapper(response: HTTPResponse, amt: Optional[int] = None) -> bytes: + def wrapper(response: UserResponse, amt: Optional[int] = None) -> bytes: start_time = time.time() - content = HTTPResponse.read(response, amt) + content = UserResponse.read(response, amt) self._update(download_path, 0, time.time() - start_time, len(content)) return content @@ -162,7 +161,7 @@ def wrapper(response: HTTPResponse, amt: Optional[int] = None) -> bytes: def _statistical_open(self, func: _OpenCallable) -> _OpenCallable: @wraps(func) - def wrapper(obj: RemoteFileMixin) -> HTTPResponse: + def wrapper(obj: RemoteFileMixin) -> UserResponse: netloc = urlparse(obj.url.get()).netloc # type: ignore[union-attr] download_path = f"[GET] {netloc}/*" diff --git a/tensorbay/client/requests.py b/tensorbay/client/requests.py index af6ec56b7..8345359cb 100644 --- a/tensorbay/client/requests.py +++ b/tensorbay/client/requests.py @@ -12,142 +12,22 @@ """ import logging -import os -from collections import defaultdict from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait from queue import Queue from threading import Lock -from typing import Any, Callable, DefaultDict, Generic, Iterable, Optional, Tuple, TypeVar +from typing import Any, Callable, Generic, Iterable, Optional, Tuple, TypeVar from urllib.parse import urljoin from uuid import uuid4 -import urllib3 -from requests import Session -from requests.adapters import HTTPAdapter -from requests.exceptions import RequestException -from requests.models import PreparedRequest, Response -from tqdm import tqdm -from urllib3.util.retry import Retry +from requests.models import Response from tensorbay.__version__ import __version__ -from tensorbay.client.log import RequestLogging, ResponseLogging from tensorbay.exception import ResponseError, ResponseErrorDistributor -from tensorbay.utility import config +from tensorbay.utility import Tqdm, config, get_session logger = logging.getLogger(__name__) -def _get_allowed_methods_keyword() -> str: - splits = urllib3.__version__.split(".", 2) - major = int(splits[0]) - minor = int(splits[1]) - return "allowed_methods" if (major, minor) >= (1, 26) else "method_whitelist" - - -# check the version of urllib3 and choose the correct keyword for "allowed_methods" in "Retry" -_ALLOWED_METHODS = _get_allowed_methods_keyword() - - -class TimeoutHTTPAdapter(HTTPAdapter): - """This class defines the http adapter for setting the timeout value. - - Arguments: - *args: Extra arguments to initialize TimeoutHTTPAdapter. - timeout: Timeout value of the post request in seconds. - **kwargs: Extra keyword arguments to initialize TimeoutHTTPAdapter. - - """ - - def __init__(self, *args: Any, timeout: Optional[int] = None, **kwargs: Any) -> None: - self.timeout = timeout if timeout is not None else config.timeout - super().__init__(*args, **kwargs) - - def send( # pylint: disable=too-many-arguments - self, - request: PreparedRequest, - stream: Any = False, - timeout: Any = None, - verify: Any = True, - cert: Any = None, - proxies: Any = None, - ) -> Any: - """Send the request. - - Arguments: - request: The PreparedRequest being sent. - stream: Whether to stream the request content. - timeout: Timeout value of the post request in seconds. - verify: A path string to a CA bundle to use or - a boolean which controls whether to verify the server's TLS certificate. - cert: User-provided SSL certificate. - proxies: Proxies dict applying to the request. - - Returns: - Response object. - - """ - if timeout is None: - timeout = self.timeout - return super().send(request, stream, timeout, verify, cert, proxies) - - -class UserSession(Session): - """This class defines UserSession.""" - - def __init__(self) -> None: - super().__init__() - # self.session.hooks["response"] = [logging_hook] - - retry_strategy = Retry( - total=config.max_retries, - status_forcelist=config.allowed_retry_status, - raise_on_status=False, - **{_ALLOWED_METHODS: config.allowed_retry_methods}, - ) - - self.mount("http://", TimeoutHTTPAdapter(20, 20, retry_strategy)) - self.mount("https://", TimeoutHTTPAdapter(20, 20, retry_strategy)) - - def request( # type: ignore[override] - self, method: str, url: str, *args: Any, **kwargs: Any - ) -> Response: # noqa: DAR401 - """Make the request. - - Arguments: - method: Method for the request. - url: URL for the request. - *args: Extra arguments to make the request. - **kwargs: Extra keyword arguments to make the request. - - Returns: - Response of the request. - - Raises: - ResponseError: If post response error. - - """ - try: - response = super().request(method, url, *args, **kwargs) - if response.status_code not in (200, 201): - logger.error( - "Unexpected status code(%d)!%s", response.status_code, ResponseLogging(response) - ) - raise ResponseError(response=response) - - logger.debug(ResponseLogging(response)) - return response - - except RequestException as error: - logger.error( - "%s.%s: %s%s", - error.__class__.__module__, - error.__class__.__name__, - error, - RequestLogging(error.request), - ) - raise - - class Client: """This class defines :class:`Client`. @@ -178,7 +58,6 @@ def __init__(self, access_key: str, url: str = "") -> None: self.gateway_url = urljoin(url, "gatewayv2/") self.access_key = access_key - self._sessions: DefaultDict[int, UserSession] = defaultdict(UserSession) self._open_api = urljoin(self.gateway_url, "tensorbay-open-api/v1/") def _url_make(self, section: str, dataset_id: str = "") -> str: @@ -205,14 +84,20 @@ def _url_make(self, section: str, dataset_id: str = "") -> str: url = urljoin(self._open_api, "datasets") return url - @property - def session(self) -> UserSession: - """Create and return a session per PID so each sub-processes will use their own session. + @staticmethod + def do(method: str, url: str, **kwargs: Any) -> Response: # pylint: disable=invalid-name + """Send a request. + + Arguments: + method: The method of the request. + url: The URL of the request. + **kwargs: Extra keyword arguments to send in the GET request. Returns: - The session corresponding to the process. + Response of the request. + """ - return self._sessions[os.getpid()] + return get_session().request(method=method, url=url, **kwargs) def open_api_do( self, method: str, section: str, dataset_id: str = "", **kwargs: Any @@ -248,56 +133,8 @@ def open_api_do( response=response ) from None - def do(self, method: str, url: str, **kwargs: Any) -> Response: # pylint: disable=invalid-name - """Send a request. - - Arguments: - method: The method of the request. - url: The URL of the request. - **kwargs: Extra keyword arguments to send in the GET request. - - Returns: - Response of the request. - - """ - return self.session.request(method=method, url=url, **kwargs) - _T = TypeVar("_T") - - -class Tqdm(tqdm): # type: ignore[misc] - """A wrapper class of tqdm for showing the process bar. - - Arguments: - total: The number of excepted iterations. - disable: Whether to disable the entire progress bar. - - """ - - def __init__(self, total: int, disable: bool = False) -> None: - super().__init__(desc="Uploading", total=total, disable=disable) - - def update_callback(self, _: Any) -> None: - """Callback function for updating process bar when multithread task is done.""" - self.update() - - def update_for_skip(self, condition: bool) -> bool: - """Update process bar for the items which are skipped in builtin filter function. - - Arguments: - condition: The filter condition, the process bar will be updated if condition is False. - - Returns: - The input condition. - - """ - if not condition: - self.update() - - return condition - - _R = TypeVar("_R") diff --git a/tensorbay/utility/__init__.py b/tensorbay/utility/__init__.py index d68114a58..0ba66a613 100644 --- a/tensorbay/utility/__init__.py +++ b/tensorbay/utility/__init__.py @@ -17,7 +17,7 @@ from tensorbay.utility.itertools import chunked from tensorbay.utility.name import NameList, NameMixin, SortedNameList from tensorbay.utility.repr import ReprMixin, ReprType, repr_config -from tensorbay.utility.request_config import config +from tensorbay.utility.requests import Tqdm, UserResponse, UserSession, config, get_session from tensorbay.utility.type import TypeEnum, TypeMixin, TypeRegister from tensorbay.utility.user import ( UserMapping, @@ -41,6 +41,7 @@ "ReprMixin", "ReprType", "SortedNameList", + "Tqdm", "TypeEnum", "TypeMixin", "TypeRegister", @@ -48,7 +49,9 @@ "UserMapping", "UserMutableMapping", "UserMutableSequence", + "UserResponse", "UserSequence", + "UserSession", "attr", "attr_base", "camel", @@ -57,5 +60,6 @@ "config", "locked", "repr_config", + "get_session", "upper", ] diff --git a/tensorbay/utility/file.py b/tensorbay/utility/file.py index f4cf1b045..b6d03ef07 100644 --- a/tensorbay/utility/file.py +++ b/tensorbay/utility/file.py @@ -7,17 +7,15 @@ import os from hashlib import sha1 -from http.client import HTTPResponse -from string import printable from typing import Any, Callable, Dict, Optional, Union -from urllib.error import HTTPError -from urllib.parse import quote, urljoin -from urllib.request import pathname2url, urlopen +from urllib.parse import urljoin +from urllib.request import pathname2url from _io import BufferedReader +from tensorbay.exception import ResponseError from tensorbay.utility.repr import ReprMixin -from tensorbay.utility.request_config import config +from tensorbay.utility.requests import UserResponse, config, get_session class URL: @@ -161,19 +159,21 @@ def __init__( def _repr_head(self) -> str: return f'{self.__class__.__name__}("{self.path}")' - def _urlopen(self) -> HTTPResponse: + def _urlopen(self) -> UserResponse: + url = self.get_url() if not self.url: raise ValueError(f"The file cannot open because {self._repr_head()} has no url") try: - return urlopen( # type: ignore[no-any-return] - quote(self.url.get(), safe=printable), timeout=config.timeout - ) - except HTTPError as error: - if error.code == 403: + session = get_session() + return UserResponse(session.request("GET", url, timeout=config.timeout, stream=True)) + except ResponseError as error: + if error.response.status_code == 403: self.url.update() - return urlopen(quote(self.url.get(), safe=printable)) # type: ignore[no-any-return] + return UserResponse( + get_session().request("GET", url, timeout=config.timeout, stream=True) + ) raise def get_url(self) -> str: @@ -191,7 +191,7 @@ def get_url(self) -> str: return self.url.get() - def open(self) -> Union[HTTPResponse, BufferedReader]: + def open(self) -> Union[UserResponse, BufferedReader]: """Return the binary file pointer of this file. The remote file pointer will be obtained by ``urllib.request.urlopen()``. diff --git a/tensorbay/utility/request_config.py b/tensorbay/utility/request_config.py deleted file mode 100644 index d8f966b1f..000000000 --- a/tensorbay/utility/request_config.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Graviti. Licensed under MIT License. -# - -"""The request related configs.""" - - -class Config: - """This is a base class defining the concept of Request Config. - - Attributes: - max_retries: Maximum retry times of the request. - allowed_retry_methods: The allowed methods for retrying request. - allowed_retry_status: The allowed status for retrying request. - If both methods and status are fitted, the retrying strategy will work. - timeout: Timeout value of the request in seconds. - is_internal: Whether the request is from internal. - - """ - - def __init__(self) -> None: - - self.max_retries = 3 - self.allowed_retry_methods = ["HEAD", "OPTIONS", "POST", "PUT"] - self.allowed_retry_status = [429, 500, 502, 503, 504] - - self.timeout = 30 - self.is_internal = False - self._x_source = "PYTHON-SDK" - - -config = Config() diff --git a/tensorbay/utility/requests.py b/tensorbay/utility/requests.py new file mode 100644 index 000000000..5509b6fa5 --- /dev/null +++ b/tensorbay/utility/requests.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Graviti. Licensed under MIT License. +# + +"""The request related tools.""" + +import logging +import os +from collections import defaultdict +from typing import Any, DefaultDict, Optional + +import urllib3 +from requests import Session +from requests.adapters import HTTPAdapter +from requests.exceptions import RequestException, StreamConsumedError +from requests.models import PreparedRequest, Response +from tqdm import tqdm +from urllib3.util.retry import Retry + +from tensorbay.client.log import RequestLogging, ResponseLogging +from tensorbay.exception import ResponseError + +logger = logging.getLogger(__name__) + + +_CHUNK_SIZE = 8 * 1024 + + +def _get_allowed_methods_keyword() -> str: + splits = urllib3.__version__.split(".", 2) + major = int(splits[0]) + minor = int(splits[1]) + return "allowed_methods" if (major, minor) >= (1, 26) else "method_whitelist" + + +# check the version of urllib3 and choose the correct keyword for "allowed_methods" in "Retry" +_ALLOWED_METHODS = _get_allowed_methods_keyword() + + +class Config: + """This is a base class defining the concept of Request Config. + + Attributes: + max_retries: Maximum retry times of the request. + allowed_retry_methods: The allowed methods for retrying request. + allowed_retry_status: The allowed status for retrying request. + If both methods and status are fitted, the retrying strategy will work. + timeout: Timeout value of the request in seconds. + is_internal: Whether the request is from internal. + + """ + + def __init__(self) -> None: + + self.max_retries = 3 + self.allowed_retry_methods = ["HEAD", "OPTIONS", "POST", "PUT"] + self.allowed_retry_status = [429, 500, 502, 503, 504] + + self.timeout = 30 + self.is_internal = False + self._x_source = "PYTHON-SDK" + + +config = Config() + + +class TimeoutHTTPAdapter(HTTPAdapter): + """This class defines the http adapter for setting the timeout value. + + Arguments: + *args: Extra arguments to initialize TimeoutHTTPAdapter. + timeout: Timeout value of the post request in seconds. + **kwargs: Extra keyword arguments to initialize TimeoutHTTPAdapter. + + """ + + def __init__(self, *args: Any, timeout: Optional[int] = None, **kwargs: Any) -> None: + self.timeout = timeout if timeout is not None else config.timeout + super().__init__(*args, **kwargs) + + def send( # pylint: disable=too-many-arguments + self, + request: PreparedRequest, + stream: Any = False, + timeout: Any = None, + verify: Any = True, + cert: Any = None, + proxies: Any = None, + ) -> Any: + """Send the request. + + Arguments: + request: The PreparedRequest being sent. + stream: Whether to stream the request content. + timeout: Timeout value of the post request in seconds. + verify: A path string to a CA bundle to use or + a boolean which controls whether to verify the server's TLS certificate. + cert: User-provided SSL certificate. + proxies: Proxies dict applying to the request. + + Returns: + Response object. + + """ + if timeout is None: + timeout = self.timeout + return super().send(request, stream, timeout, verify, cert, proxies) + + +class UserSession(Session): + """This class defines UserSession.""" + + def __init__(self) -> None: + super().__init__() + # self.session.hooks["response"] = [logging_hook] + + retry_strategy = Retry( + total=config.max_retries, + status_forcelist=config.allowed_retry_status, + raise_on_status=False, + **{_ALLOWED_METHODS: config.allowed_retry_methods}, + ) + + self.mount("http://", TimeoutHTTPAdapter(20, 20, retry_strategy)) + self.mount("https://", TimeoutHTTPAdapter(20, 20, retry_strategy)) + + def request( # type: ignore[override] + self, method: str, url: str, *args: Any, **kwargs: Any + ) -> Response: # noqa: DAR401 + """Make the request. + + Arguments: + method: Method for the request. + url: URL for the request. + *args: Extra arguments to make the request. + **kwargs: Extra keyword arguments to make the request. + + Returns: + Response of the request. + + Raises: + ResponseError: If post response error. + + """ + try: + response = super().request(method, url, *args, **kwargs) + if response.status_code not in (200, 201): + logger.error( + "Unexpected status code(%d)!%s", response.status_code, ResponseLogging(response) + ) + raise ResponseError(response=response) + + logger.debug(ResponseLogging(response)) + return response + + except RequestException as error: + logger.error( + "%s.%s: %s%s", + error.__class__.__module__, + error.__class__.__name__, + error, + RequestLogging(error.request), + ) + raise + + +SESSIONS: DefaultDict[int, UserSession] = defaultdict(UserSession) + + +def get_session() -> UserSession: + """Create and return a session per PID so each sub-processes will use their own session. + + Returns: + The session corresponding to the process. + """ + return SESSIONS[os.getpid()] + + +class UserResponse: + """This class used to read data from Response with stream method. + + Arguments: + response: Response of the Session.request(). + + """ + + def __init__(self, response: Response): + self.response = response + + def __enter__(self) -> "UserResponse": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + def close(self) -> None: + """Close the response.""" + self.response.close() + + def read(self, amt: Optional[int] = None) -> bytes: + """Read data from response. + + Arguments: + amt: The needed read amount. + + Returns: + Response of the request. + + """ + if amt is None: + try: + return b"".join(chunk for chunk in self.response.iter_content(_CHUNK_SIZE)) + + except StreamConsumedError: + return b"" + + try: + return next(self.response.iter_content(amt)) # type: ignore[no-any-return] + + except StopIteration: + return b"" + + +class Tqdm(tqdm): # type: ignore[misc] + """A wrapper class of tqdm for showing the process bar. + + Arguments: + total: The number of excepted iterations. + disable: Whether to disable the entire progress bar. + + """ + + def __init__(self, total: int, disable: bool = False) -> None: + super().__init__(desc="Uploading", total=total, disable=disable) + + def update_callback(self, _: Any) -> None: + """Callback function for updating process bar when multithread task is done.""" + self.update() + + def update_for_skip(self, condition: bool) -> bool: + """Update process bar for the items which are skipped in builtin filter function. + + Arguments: + condition: The filter condition, the process bar will be updated if condition is False. + + Returns: + The input condition. + + """ + if not condition: + self.update() + + return condition