Skip to content

Commit

Permalink
feat(client): use "session" instead of "urlopen" in "RemoteData.open"
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen.chen committed Dec 6, 2021
1 parent 670a3e2 commit c1ee389
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 233 deletions.
3 changes: 2 additions & 1 deletion tensorbay/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
UPLOAD_SEGMENT_RESUME_TEMPLATE_CLI,
UPLOAD_SEGMENT_RESUME_TEMPLATE_SDK,
)
from tensorbay.client.requests import Tqdm, multithread_upload
from tensorbay.client.requests import multithread_upload
from tensorbay.client.segment import _STRATEGIES, FusionSegmentClient, SegmentClient
from tensorbay.client.statistics import Statistics
from tensorbay.client.status import Status
Expand All @@ -51,6 +51,7 @@
StatusError,
)
from tensorbay.label import Catalog
from tensorbay.utility import Tqdm

if TYPE_CHECKING:
from tensorbay.client.gas import GAS
Expand Down
3 changes: 2 additions & 1 deletion tensorbay/client/gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from tensorbay.client.dataset import DatasetClient, FusionDatasetClient
from tensorbay.client.lazy import PagingList
from tensorbay.client.log import UPLOAD_DATASET_RESUME_TEMPLATE
from tensorbay.client.requests import Client, Tqdm
from tensorbay.client.requests import Client
from tensorbay.client.status import Status
from tensorbay.client.struct import ROOT_COMMIT_ID, UserInfo
from tensorbay.dataset import Dataset, FusionDataset
from tensorbay.exception import DatasetTypeError, ResourceNotExistError
from tensorbay.utility import Tqdm

DatasetClientType = Union[DatasetClient, FusionDatasetClient]

Expand Down
11 changes: 5 additions & 6 deletions tensorbay/client/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import time
from collections import OrderedDict, defaultdict
from functools import wraps
from http.client import HTTPResponse
from itertools import chain
from multiprocessing import Manager
from multiprocessing.managers import SyncManager
Expand All @@ -27,10 +26,10 @@
from requests_toolbelt.multipart.encoder import FileWrapper, MultipartEncoder

from tensorbay.client.requests import Client
from tensorbay.utility.file import RemoteFileMixin
from tensorbay.utility import RemoteFileMixin, UserResponse

_Callable = TypeVar("_Callable", bound=Callable[..., Response])
_OpenCallable = TypeVar("_OpenCallable", bound=Callable[..., HTTPResponse])
_OpenCallable = TypeVar("_OpenCallable", bound=Callable[..., UserResponse])
_ReadCallable = Callable[..., bytes]

_COLUMNS = OrderedDict(
Expand Down Expand Up @@ -152,17 +151,17 @@ def wrapper(client: Client, method: str, url: str, **kwargs: Any) -> Response:
return wrapper # type: ignore[return-value]

def _statistical_read(self, download_path: str) -> _ReadCallable:
def wrapper(response: HTTPResponse, amt: Optional[int] = None) -> bytes:
def wrapper(response: UserResponse, amt: Optional[int] = None) -> bytes:
start_time = time.time()
content = HTTPResponse.read(response, amt)
content = UserResponse.read(response, amt)
self._update(download_path, 0, time.time() - start_time, len(content))
return content

return wrapper

def _statistical_open(self, func: _OpenCallable) -> _OpenCallable:
@wraps(func)
def wrapper(obj: RemoteFileMixin) -> HTTPResponse:
def wrapper(obj: RemoteFileMixin) -> UserResponse:
netloc = urlparse(obj.url.get()).netloc # type: ignore[union-attr]
download_path = f"[GET] {netloc}/*"

Expand Down
191 changes: 14 additions & 177 deletions tensorbay/client/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,142 +12,22 @@
"""

import logging
import os
from collections import defaultdict
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from queue import Queue
from threading import Lock
from typing import Any, Callable, DefaultDict, Generic, Iterable, Optional, Tuple, TypeVar
from typing import Any, Callable, Generic, Iterable, Optional, Tuple, TypeVar
from urllib.parse import urljoin
from uuid import uuid4

import urllib3
from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from requests.models import PreparedRequest, Response
from tqdm import tqdm
from urllib3.util.retry import Retry
from requests.models import Response

from tensorbay.__version__ import __version__
from tensorbay.client.log import RequestLogging, ResponseLogging
from tensorbay.exception import ResponseError, ResponseErrorDistributor
from tensorbay.utility import config
from tensorbay.utility import Tqdm, config, get_session

logger = logging.getLogger(__name__)


def _get_allowed_methods_keyword() -> str:
splits = urllib3.__version__.split(".", 2)
major = int(splits[0])
minor = int(splits[1])
return "allowed_methods" if (major, minor) >= (1, 26) else "method_whitelist"


# check the version of urllib3 and choose the correct keyword for "allowed_methods" in "Retry"
_ALLOWED_METHODS = _get_allowed_methods_keyword()


class TimeoutHTTPAdapter(HTTPAdapter):
"""This class defines the http adapter for setting the timeout value.
Arguments:
*args: Extra arguments to initialize TimeoutHTTPAdapter.
timeout: Timeout value of the post request in seconds.
**kwargs: Extra keyword arguments to initialize TimeoutHTTPAdapter.
"""

def __init__(self, *args: Any, timeout: Optional[int] = None, **kwargs: Any) -> None:
self.timeout = timeout if timeout is not None else config.timeout
super().__init__(*args, **kwargs)

def send( # pylint: disable=too-many-arguments
self,
request: PreparedRequest,
stream: Any = False,
timeout: Any = None,
verify: Any = True,
cert: Any = None,
proxies: Any = None,
) -> Any:
"""Send the request.
Arguments:
request: The PreparedRequest being sent.
stream: Whether to stream the request content.
timeout: Timeout value of the post request in seconds.
verify: A path string to a CA bundle to use or
a boolean which controls whether to verify the server's TLS certificate.
cert: User-provided SSL certificate.
proxies: Proxies dict applying to the request.
Returns:
Response object.
"""
if timeout is None:
timeout = self.timeout
return super().send(request, stream, timeout, verify, cert, proxies)


class UserSession(Session):
"""This class defines UserSession."""

def __init__(self) -> None:
super().__init__()
# self.session.hooks["response"] = [logging_hook]

retry_strategy = Retry(
total=config.max_retries,
status_forcelist=config.allowed_retry_status,
raise_on_status=False,
**{_ALLOWED_METHODS: config.allowed_retry_methods},
)

self.mount("http://", TimeoutHTTPAdapter(20, 20, retry_strategy))
self.mount("https://", TimeoutHTTPAdapter(20, 20, retry_strategy))

def request( # type: ignore[override]
self, method: str, url: str, *args: Any, **kwargs: Any
) -> Response: # noqa: DAR401
"""Make the request.
Arguments:
method: Method for the request.
url: URL for the request.
*args: Extra arguments to make the request.
**kwargs: Extra keyword arguments to make the request.
Returns:
Response of the request.
Raises:
ResponseError: If post response error.
"""
try:
response = super().request(method, url, *args, **kwargs)
if response.status_code not in (200, 201):
logger.error(
"Unexpected status code(%d)!%s", response.status_code, ResponseLogging(response)
)
raise ResponseError(response=response)

logger.debug(ResponseLogging(response))
return response

except RequestException as error:
logger.error(
"%s.%s: %s%s",
error.__class__.__module__,
error.__class__.__name__,
error,
RequestLogging(error.request),
)
raise


class Client:
"""This class defines :class:`Client`.
Expand Down Expand Up @@ -178,7 +58,6 @@ def __init__(self, access_key: str, url: str = "") -> None:
self.gateway_url = urljoin(url, "gatewayv2/")
self.access_key = access_key

self._sessions: DefaultDict[int, UserSession] = defaultdict(UserSession)
self._open_api = urljoin(self.gateway_url, "tensorbay-open-api/v1/")

def _url_make(self, section: str, dataset_id: str = "") -> str:
Expand All @@ -205,14 +84,20 @@ def _url_make(self, section: str, dataset_id: str = "") -> str:
url = urljoin(self._open_api, "datasets")
return url

@property
def session(self) -> UserSession:
"""Create and return a session per PID so each sub-processes will use their own session.
@staticmethod
def do(method: str, url: str, **kwargs: Any) -> Response: # pylint: disable=invalid-name
"""Send a request.
Arguments:
method: The method of the request.
url: The URL of the request.
**kwargs: Extra keyword arguments to send in the GET request.
Returns:
The session corresponding to the process.
Response of the request.
"""
return self._sessions[os.getpid()]
return get_session().request(method=method, url=url, **kwargs)

def open_api_do(
self, method: str, section: str, dataset_id: str = "", **kwargs: Any
Expand Down Expand Up @@ -248,56 +133,8 @@ def open_api_do(
response=response
) from None

def do(self, method: str, url: str, **kwargs: Any) -> Response: # pylint: disable=invalid-name
"""Send a request.
Arguments:
method: The method of the request.
url: The URL of the request.
**kwargs: Extra keyword arguments to send in the GET request.
Returns:
Response of the request.
"""
return self.session.request(method=method, url=url, **kwargs)


_T = TypeVar("_T")


class Tqdm(tqdm): # type: ignore[misc]
"""A wrapper class of tqdm for showing the process bar.
Arguments:
total: The number of excepted iterations.
disable: Whether to disable the entire progress bar.
"""

def __init__(self, total: int, disable: bool = False) -> None:
super().__init__(desc="Uploading", total=total, disable=disable)

def update_callback(self, _: Any) -> None:
"""Callback function for updating process bar when multithread task is done."""
self.update()

def update_for_skip(self, condition: bool) -> bool:
"""Update process bar for the items which are skipped in builtin filter function.
Arguments:
condition: The filter condition, the process bar will be updated if condition is False.
Returns:
The input condition.
"""
if not condition:
self.update()

return condition


_R = TypeVar("_R")


Expand Down
6 changes: 5 additions & 1 deletion tensorbay/utility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tensorbay.utility.itertools import chunked
from tensorbay.utility.name import NameList, NameMixin, SortedNameList
from tensorbay.utility.repr import ReprMixin, ReprType, repr_config
from tensorbay.utility.request_config import config
from tensorbay.utility.requests import Tqdm, UserResponse, UserSession, config, get_session
from tensorbay.utility.type import TypeEnum, TypeMixin, TypeRegister
from tensorbay.utility.user import (
UserMapping,
Expand All @@ -41,14 +41,17 @@
"ReprMixin",
"ReprType",
"SortedNameList",
"Tqdm",
"TypeEnum",
"TypeMixin",
"TypeRegister",
"URL",
"UserMapping",
"UserMutableMapping",
"UserMutableSequence",
"UserResponse",
"UserSequence",
"UserSession",
"attr",
"attr_base",
"camel",
Expand All @@ -57,5 +60,6 @@
"config",
"locked",
"repr_config",
"get_session",
"upper",
]
Loading

0 comments on commit c1ee389

Please sign in to comment.