From f04ecc241119b44fa52658cc17c5eab4adcb1648 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 24 Oct 2018 21:44:56 +0200 Subject: [PATCH] Annotate client.py --- aiohttp/client.py | 387 ++++++++++++++++++++--------------- aiohttp/client_exceptions.py | 10 +- aiohttp/cookiejar.py | 5 +- aiohttp/tracing.py | 4 +- aiohttp/typedefs.py | 7 +- 5 files changed, 240 insertions(+), 173 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index e58d1377c8e..542c9cca2e5 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -8,8 +8,9 @@ import sys import traceback import warnings -from collections.abc import Coroutine as CoroutineABC -from typing import Any, Generator, Optional, Tuple +from types import SimpleNamespace, TracebackType +from typing import (Any, Coroutine, Generator, Generic, Iterable, List, # noqa + Mapping, Optional, Set, Tuple, Type, TypeVar, Union) import attr from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr @@ -18,24 +19,27 @@ from . import client_exceptions, client_reqrep from . import connector as connector_mod from . import hdrs, http, payload +from .abc import AbstractCookieJar from .client_exceptions import * # noqa from .client_exceptions import (ClientError, ClientOSError, InvalidURL, ServerTimeoutError, TooManyRedirects, WSServerHandshakeError) from .client_reqrep import * # noqa -from .client_reqrep import ClientRequest, ClientResponse, _merge_ssl_params +from .client_reqrep import (ClientRequest, ClientResponse, Fingerprint, + _merge_ssl_params) from .client_ws import ClientWebSocketResponse from .connector import * # noqa from .connector import BaseConnector, TCPConnector from .cookiejar import CookieJar -from .helpers import (DEBUG, PY_36, CeilTimeout, TimeoutHandle, +from .helpers import (DEBUG, PY_36, BasicAuth, CeilTimeout, TimeoutHandle, proxies_from_env, sentinel, strip_auth_from_url) -from .http import WS_KEY, WebSocketReader, WebSocketWriter -from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse +from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter +from .http_websocket import (WSHandshakeError, WSMessage, ws_ext_gen, # noqa + ws_ext_parse) from .streams import FlowControlDataQueue from .tcp_helpers import tcp_cork, tcp_nodelay -from .tracing import Trace -from .typedefs import StrOrURL +from .tracing import Trace, TraceConfig +from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL __all__ = (client_exceptions.__all__ + # noqa @@ -45,6 +49,12 @@ 'ClientWebSocketResponse', 'request')) +try: + from ssl import SSLContext +except ImportError: # pragma: no cover + SSLContext = object # type: ignore + + @attr.s(frozen=True, slots=True) class ClientTimeout: total = attr.ib(type=float, default=None) @@ -69,6 +79,8 @@ class ClientTimeout: # 5 Minute default read timeout DEFAULT_TIMEOUT = ClientTimeout(total=5*60) +_RetType = TypeVar('_RetType') + class ClientSession: """First-class interface for making HTTP requests.""" @@ -88,18 +100,26 @@ class ClientSession: requote_redirect_url = True - def __init__(self, *, connector=None, loop=None, cookies=None, - headers=None, skip_auto_headers=None, - auth=None, json_serialize=json.dumps, - request_class=ClientRequest, response_class=ClientResponse, - ws_response_class=ClientWebSocketResponse, - version=http.HttpVersion11, - cookie_jar=None, connector_owner=True, - raise_for_status=False, - read_timeout=sentinel, conn_timeout=None, - timeout=sentinel, - auto_decompress=True, trust_env=False, - trace_configs=None) -> None: + def __init__(self, *, connector: Optional[BaseConnector]=None, + loop: Optional[asyncio.AbstractEventLoop]=None, + cookies: Optional[LooseCookies]=None, + headers: LooseHeaders=None, + skip_auto_headers: Optional[Iterable[str]]=None, + auth: Optional[BasicAuth]=None, + json_serialize: JSONEncoder=json.dumps, + request_class: Type[ClientRequest]=ClientRequest, + response_class: Type[ClientResponse]=ClientResponse, + ws_response_class: Type[ClientWebSocketResponse]=ClientWebSocketResponse, # noqa + version: HttpVersion=http.HttpVersion11, + cookie_jar: Optional[AbstractCookieJar]=None, + connector_owner: bool=True, + raise_for_status: bool=False, + read_timeout: Union[float, object]=sentinel, + conn_timeout: Optional[float]=None, + timeout: Union[object, ClientTimeout]=sentinel, + auto_decompress: bool=True, + trust_env: bool=False, + trace_configs: Optional[List[TraceConfig]]=None) -> None: implicit_loop = False if loop is None: @@ -144,8 +164,15 @@ def __init__(self, *, connector=None, loop=None, cookies=None, self._default_auth = auth self._version = version self._json_serialize = json_serialize - if timeout is not sentinel: - self._timeout = timeout + if timeout is sentinel: + self._timeout = DEFAULT_TIMEOUT + if read_timeout is not sentinel: + self._timeout = attr.evolve(self._timeout, total=read_timeout) + if conn_timeout is not None: + self._timeout = attr.evolve(self._timeout, + connect=conn_timeout) + else: + self._timeout = timeout # type: ignore if read_timeout is not sentinel: raise ValueError("read_timeout and timeout parameters " "conflict, please setup " @@ -154,13 +181,6 @@ def __init__(self, *, connector=None, loop=None, cookies=None, raise ValueError("conn_timeout and timeout parameters " "conflict, please setup " "timeout.connect") - else: - self._timeout = DEFAULT_TIMEOUT - if read_timeout is not sentinel: - self._timeout = attr.evolve(self._timeout, total=read_timeout) - if conn_timeout is not None: - self._timeout = attr.evolve(self._timeout, - connect=conn_timeout) self._raise_for_status = raise_for_status self._auto_decompress = auto_decompress self._trust_env = trust_env @@ -185,14 +205,14 @@ def __init__(self, *, connector=None, loop=None, cookies=None, for trace_config in self._trace_configs: trace_config.freeze() - def __init_subclass__(cls): + def __init_subclass__(cls: Type['ClientSession']) -> None: warnings.warn("Inheritance class {} from ClientSession " "is discouraged".format(cls.__name__), DeprecationWarning, stacklevel=2) if DEBUG: - def __setattr__(self, name, val): + def __setattr__(self, name: str, val: Any) -> None: if name not in self.ATTRS: warnings.warn("Setting custom ClientSession.{} attribute " "is discouraged".format(name), @@ -200,7 +220,7 @@ def __setattr__(self, name, val): stacklevel=2) super().__setattr__(name, val) - def __del__(self, _warnings=warnings): + def __del__(self, _warnings: Any=warnings) -> None: if not self.closed: if PY_36: kwargs = {'source': self} @@ -218,33 +238,37 @@ def __del__(self, _warnings=warnings): def request(self, method: str, url: StrOrURL, - **kwargs) -> '_RequestContextManager': + **kwargs: Any) -> '_RequestContextManager': """Perform HTTP request.""" return _RequestContextManager(self._request(method, url, **kwargs)) - async def _request(self, method, url, *, - params=None, - data=None, - json=None, - headers=None, - skip_auto_headers=None, - auth=None, - allow_redirects=True, - max_redirects=10, - compress=None, - chunked=None, - expect100=False, - raise_for_status=None, - read_until_eof=True, - proxy=None, - proxy_auth=None, - timeout=sentinel, - verify_ssl=None, - fingerprint=None, - ssl_context=None, - ssl=None, - proxy_headers=None, - trace_request_ctx=None): + async def _request( + self, + method: str, + str_or_url: StrOrURL, *, + params: Optional[Mapping[str, str]]=None, + data: Any=None, + json: Any=None, + headers: LooseHeaders=None, + skip_auto_headers: Optional[Iterable[str]]=None, + auth: Optional[BasicAuth]=None, + allow_redirects: bool=True, + max_redirects: int=10, + compress: Optional[str]=None, + chunked: Optional[bool]=None, + expect100: bool=False, + raise_for_status: Optional[bool]=None, + read_until_eof: bool=True, + proxy: Optional[StrOrURL]=None, + proxy_auth: Optional[BasicAuth]=None, + timeout: Union[ClientTimeout, object]=sentinel, + verify_ssl: Optional[bool]=None, + fingerprint: Optional[Fingerprint]=None, + ssl_context: Optional[SSLContext]=None, + ssl: Optional[Union[SSLContext, bool, Fingerprint]]=None, + proxy_headers: Optional[LooseHeaders]=None, + trace_request_ctx: Optional[SimpleNamespace]=None + ) -> ClientResponse: # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants @@ -274,9 +298,9 @@ async def _request(self, method, url, *, proxy_headers = self._prepare_headers(proxy_headers) try: - url = URL(url) + url = URL(str_or_url) except ValueError: - raise InvalidURL(url) + raise InvalidURL(str_or_url) skip_headers = set(self._skip_auto_headers) if skip_auto_headers is not None: @@ -290,13 +314,15 @@ async def _request(self, method, url, *, raise InvalidURL(proxy) if timeout is sentinel: - timeout = self._timeout + real_timeout = self._timeout # type: ClientTimeout else: if not isinstance(timeout, ClientTimeout): - timeout = ClientTimeout(total=timeout) + real_timeout = ClientTimeout(total=timeout) # type: ignore + else: + real_timeout = timeout # timeout is cumulative for all request operations # (request, redirects, responses, data consuming) - tm = TimeoutHandle(self._loop, timeout.total) + tm = TimeoutHandle(self._loop, real_timeout.total) handle = tm.start() traces = [ @@ -362,27 +388,30 @@ async def _request(self, method, url, *, # connection timeout try: - with CeilTimeout(self._timeout.connect, + with CeilTimeout(real_timeout.connect, loop=self._loop): + assert self._connector is not None conn = await self._connector.connect( req, traces=traces, - timeout=timeout + timeout=real_timeout ) except asyncio.TimeoutError as exc: raise ServerTimeoutError( 'Connection timeout ' 'to host {0}'.format(url)) from exc + assert conn.transport is not None tcp_nodelay(conn.transport, True) tcp_cork(conn.transport, False) + assert conn.protocol is not None conn.protocol.set_response_params( timer=timer, skip_payload=method.upper() == 'HEAD', read_until_eof=read_until_eof, auto_decompress=self._auto_decompress, - read_timeout=timeout.sock_read) + read_timeout=real_timeout.sock_read) try: try: @@ -508,25 +537,27 @@ async def _request(self, method, url, *, ) raise - def ws_connect(self, url: StrOrURL, *, - protocols=(), - timeout=10.0, - receive_timeout=None, - autoclose=True, - autoping=True, - heartbeat=None, - auth=None, - origin=None, - headers=None, - proxy=None, - proxy_auth=None, - ssl=None, - verify_ssl=None, - fingerprint=None, - ssl_context=None, - proxy_headers=None, - compress=0, - max_msg_size=4*1024*1024): + def ws_connect( + self, + url: StrOrURL, *, + protocols: Iterable[str]=(), + timeout: float=10.0, + receive_timeout: Optional[float]=None, + autoclose: bool=True, + autoping: bool=True, + heartbeat: Optional[float]=None, + auth: Optional[BasicAuth]=None, + origin: Optional[str]=None, + headers: Optional[LooseHeaders]=None, + proxy: Optional[StrOrURL]=None, + proxy_auth: Optional[BasicAuth]=None, + ssl: Union[SSLContext, bool, None, Fingerprint]=None, + verify_ssl: Optional[bool]=None, + fingerprint: Optional[Fingerprint]=None, + ssl_context: Optional[SSLContext]=None, + proxy_headers: Optional[LooseHeaders]=None, + compress: int=0, + max_msg_size: int=4*1024*1024) -> '_WSRequestContextManager': """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect(url, @@ -549,28 +580,33 @@ def ws_connect(self, url: StrOrURL, *, compress=compress, max_msg_size=max_msg_size)) - async def _ws_connect(self, url, *, - protocols=(), - timeout=10.0, - receive_timeout=None, - autoclose=True, - autoping=True, - heartbeat=None, - auth=None, - origin=None, - headers=None, - proxy=None, - proxy_auth=None, - ssl=None, - verify_ssl=None, - fingerprint=None, - ssl_context=None, - proxy_headers=None, - compress=0, - max_msg_size=4*1024*1024): + async def _ws_connect( + self, + url: StrOrURL, *, + protocols: Iterable[str]=(), + timeout: float=10.0, + receive_timeout: Optional[float]=None, + autoclose: bool=True, + autoping: bool=True, + heartbeat: Optional[float]=None, + auth: Optional[BasicAuth]=None, + origin: Optional[str]=None, + headers: Optional[LooseHeaders]=None, + proxy: Optional[StrOrURL]=None, + proxy_auth: Optional[BasicAuth]=None, + ssl: Union[SSLContext, bool, None, Fingerprint]=None, + verify_ssl: Optional[bool]=None, + fingerprint: Optional[Fingerprint]=None, + ssl_context: Optional[SSLContext]=None, + proxy_headers: Optional[LooseHeaders]=None, + compress: int=0, + max_msg_size: int=4*1024*1024 + ) -> ClientWebSocketResponse: if headers is None: - headers = CIMultiDict() + real_headers = CIMultiDict() # type: CIMultiDict[str] + else: + real_headers = CIMultiDict(headers) default_headers = { hdrs.UPGRADE: hdrs.WEBSOCKET, @@ -579,24 +615,23 @@ async def _ws_connect(self, url, *, } for key, value in default_headers.items(): - if key not in headers: - headers[key] = value + real_headers.setdefault(key, value) sec_key = base64.b64encode(os.urandom(16)) - headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() + real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() if protocols: - headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) + real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) if origin is not None: - headers[hdrs.ORIGIN] = origin + real_headers[hdrs.ORIGIN] = origin if compress: extstr = ws_ext_gen(compress=compress) - headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) # send request - resp = await self.get(url, headers=headers, + resp = await self.get(url, headers=real_headers, read_until_eof=False, auth=auth, proxy=proxy, @@ -675,7 +710,7 @@ async def _ws_connect(self, url, *, proto = resp.connection.protocol transport = resp.connection.transport reader = FlowControlDataQueue( - proto, limit=2 ** 16, loop=self._loop) + proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa proto.set_parser(WebSocketReader(reader, max_msg_size), reader) tcp_nodelay(transport, True) writer = WebSocketWriter( @@ -698,7 +733,9 @@ async def _ws_connect(self, url, *, compress=compress, client_notakeover=notakeover) - def _prepare_headers(self, headers): + def _prepare_headers( + self, + headers: Optional[LooseHeaders]) -> 'CIMultiDict[str]': """ Add default headers and transform it to CIMultiDict """ # Convert headers to MultiDict @@ -706,7 +743,7 @@ def _prepare_headers(self, headers): if headers: if not isinstance(headers, (MultiDictProxy, MultiDict)): headers = CIMultiDict(headers) - added_names = set() + added_names = set() # type: Set[str] for key, value in headers.items(): if key in added_names: result.add(key, value) @@ -716,7 +753,7 @@ def _prepare_headers(self, headers): return result def get(self, url: StrOrURL, *, allow_redirects: bool=True, - **kwargs) -> '_RequestContextManager': + **kwargs: Any) -> '_RequestContextManager': """Perform HTTP GET request.""" return _RequestContextManager( self._request(hdrs.METH_GET, url, @@ -724,7 +761,7 @@ def get(self, url: StrOrURL, *, allow_redirects: bool=True, **kwargs)) def options(self, url: StrOrURL, *, allow_redirects: bool=True, - **kwargs) -> '_RequestContextManager': + **kwargs: Any) -> '_RequestContextManager': """Perform HTTP OPTIONS request.""" return _RequestContextManager( self._request(hdrs.METH_OPTIONS, url, @@ -732,7 +769,7 @@ def options(self, url: StrOrURL, *, allow_redirects: bool=True, **kwargs)) def head(self, url: StrOrURL, *, allow_redirects: bool=False, - **kwargs) -> '_RequestContextManager': + **kwargs: Any) -> '_RequestContextManager': """Perform HTTP HEAD request.""" return _RequestContextManager( self._request(hdrs.METH_HEAD, url, @@ -740,7 +777,7 @@ def head(self, url: StrOrURL, *, allow_redirects: bool=False, **kwargs)) def post(self, url: StrOrURL, - *, data: Any=None, **kwargs) -> '_RequestContextManager': + *, data: Any=None, **kwargs: Any) -> '_RequestContextManager': """Perform HTTP POST request.""" return _RequestContextManager( self._request(hdrs.METH_POST, url, @@ -748,7 +785,7 @@ def post(self, url: StrOrURL, **kwargs)) def put(self, url: StrOrURL, - *, data: Any=None, **kwargs) -> '_RequestContextManager': + *, data: Any=None, **kwargs: Any) -> '_RequestContextManager': """Perform HTTP PUT request.""" return _RequestContextManager( self._request(hdrs.METH_PUT, url, @@ -756,14 +793,14 @@ def put(self, url: StrOrURL, **kwargs)) def patch(self, url: StrOrURL, - *, data: Any=None, **kwargs) -> '_RequestContextManager': + *, data: Any=None, **kwargs: Any) -> '_RequestContextManager': """Perform HTTP PATCH request.""" return _RequestContextManager( self._request(hdrs.METH_PATCH, url, data=data, **kwargs)) - def delete(self, url: StrOrURL, **kwargs) -> '_RequestContextManager': + def delete(self, url: StrOrURL, **kwargs: Any) -> '_RequestContextManager': """Perform HTTP DELETE request.""" return _RequestContextManager( self._request(hdrs.METH_DELETE, url, @@ -793,7 +830,7 @@ def connector(self) -> Optional[BaseConnector]: return self._connector @property - def cookie_jar(self) -> CookieJar: + def cookie_jar(self) -> AbstractCookieJar: """The session cookies.""" return self._cookie_jar @@ -814,50 +851,65 @@ def detach(self) -> None: """ self._connector = None - def __enter__(self): + def __enter__(self) -> None: raise TypeError("Use async with instead") - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> None: # __exit__ should exist in pair with __enter__ but never executed pass # pragma: no cover async def __aenter__(self) -> 'ClientSession': return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> None: await self.close() -class _BaseRequestContextManager(CoroutineABC): +class _BaseRequestContextManager(Coroutine[Any, + Any, + _RetType], + Generic[_RetType]): __slots__ = ('_coro', '_resp') - def __init__(self, coro): + def __init__( + self, + coro: Coroutine['asyncio.Future[Any]', None, _RetType] + ) -> None: self._coro = coro - def send(self, arg): + def send(self, arg: None) -> 'asyncio.Future[Any]': return self._coro.send(arg) - def throw(self, arg): - return self._coro.throw(arg) + def throw(self, arg: BaseException) -> None: # type: ignore + self._coro.throw(arg) # type: ignore - def close(self): + def close(self) -> None: return self._coro.close() - def __await__(self) -> Generator[Any, None, ClientResponse]: + def __await__(self) -> Generator[Any, None, _RetType]: ret = self._coro.__await__() return ret - def __iter__(self): + def __iter__(self) -> Generator[Any, None, _RetType]: return self.__await__() - async def __aenter__(self) -> ClientResponse: + async def __aenter__(self) -> _RetType: self._resp = await self._coro return self._resp -class _RequestContextManager(_BaseRequestContextManager): - async def __aexit__(self, exc_type, exc, tb): +class _RequestContextManager(_BaseRequestContextManager[ClientResponse]): + async def __aexit__(self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType]) -> None: # We're basing behavior on the exception as it can be caused by # user code unrelated to the status of the connection. If you # would like to close a connection you must do that @@ -866,8 +918,12 @@ async def __aexit__(self, exc_type, exc, tb): self._resp.release() -class _WSRequestContextManager(_BaseRequestContextManager): - async def __aexit__(self, exc_type, exc, tb): +class _WSRequestContextManager(_BaseRequestContextManager[ + ClientWebSocketResponse]): + async def __aexit__(self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType]) -> None: await self._resp.close() @@ -875,41 +931,50 @@ class _SessionRequestContextManager: __slots__ = ('_coro', '_resp', '_session') - def __init__(self, coro, session): + def __init__(self, + coro: Coroutine['asyncio.Future[Any]', None, ClientResponse], + session: ClientSession) -> None: self._coro = coro - self._resp = None + self._resp = None # type: Optional[ClientResponse] self._session = session async def __aenter__(self) -> ClientResponse: self._resp = await self._coro return self._resp - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType]) -> None: + assert self._resp is not None self._resp.close() await self._session.close() -def request(method, url, *, - params=None, - data=None, - json=None, - headers=None, - skip_auto_headers=None, - cookies=None, - auth=None, - allow_redirects=True, - max_redirects=10, - version=http.HttpVersion11, - compress=None, - chunked=None, - expect100=False, - raise_for_status=False, - connector=None, - loop=None, - read_until_eof=True, - timeout=sentinel, - proxy=None, - proxy_auth=None) -> _SessionRequestContextManager: +def request( + method: str, + url: StrOrURL, *, + params: Optional[Mapping[str, str]]=None, + data: Any=None, + json: Any=None, + headers: LooseHeaders=None, + skip_auto_headers: Optional[Iterable[str]]=None, + auth: Optional[BasicAuth]=None, + allow_redirects: bool=True, + max_redirects: int=10, + compress: Optional[str]=None, + chunked: Optional[bool]=None, + expect100: bool=False, + raise_for_status: Optional[bool]=None, + read_until_eof: bool=True, + proxy: Optional[StrOrURL]=None, + proxy_auth: Optional[BasicAuth]=None, + timeout: Union[ClientTimeout, object]=sentinel, + cookies: Optional[LooseCookies]=None, + version: HttpVersion=http.HttpVersion11, + connector: Optional[BaseConnector]=None, + loop: Optional[asyncio.AbstractEventLoop]=None, +) -> _SessionRequestContextManager: """Constructs and sends a request. Returns response object. method - HTTP method url - request url diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 08194e62d17..987c5c49c63 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -2,9 +2,7 @@ import asyncio import warnings -from typing import TYPE_CHECKING, Optional, Tuple - -from yarl import URL +from typing import TYPE_CHECKING, Any, Optional, Tuple from .typedefs import _CIMultiDict @@ -205,11 +203,13 @@ class InvalidURL(ClientError, ValueError): # Derive from ValueError for backward compatibility - def __init__(self, url: URL) -> None: + def __init__(self, url: Any) -> None: + # The type of url is not yarl.URL because the exception can be raised + # on URL(url) call super().__init__(url) @property - def url(self) -> URL: + def url(self) -> Any: return self.args[0] def __repr__(self) -> str: diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index 480a07b6dde..b868de0f757 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -14,7 +14,7 @@ from .abc import AbstractCookieJar from .helpers import is_ip_address -from .typedefs import PathLike +from .typedefs import LooseCookies, PathLike __all__ = ('CookieJar', 'DummyCookieJar') @@ -102,8 +102,7 @@ def _expire_cookie(self, when: float, domain: str, name: str) -> None: self._expirations[(domain, name)] = iwhen def update_cookies(self, - cookies: Union[Iterable[Tuple[str, 'BaseCookie[str]']], - Mapping[str, 'BaseCookie[str]']], + cookies: LooseCookies, response_url: URL=URL()) -> None: """Update cookies.""" hostname = response_url.raw_host diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index 960e47037d4..bc06c81a011 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -176,7 +176,7 @@ class TraceRequestExceptionParams: method = attr.ib(type=str) url = attr.ib(type=URL) headers = attr.ib(type='CIMultiDict[str]') - exception = attr.ib(type=Exception) + exception = attr.ib(type=BaseException) @attr.s(frozen=True, slots=True) @@ -288,7 +288,7 @@ async def send_request_exception(self, method: str, url: URL, headers: 'CIMultiDict[str]', - exception: Exception) -> None: + exception: BaseException) -> None: return await self._trace_config.on_request_exception.send( self._session, self._trace_config_ctx, diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index ec794cd963e..272c1246c0b 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -2,8 +2,8 @@ import os # noqa import pathlib # noqa import sys -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, # noqa - Tuple, Union) +from typing import (TYPE_CHECKING, Any, Callable, Iterable, Mapping, Tuple, + Union) from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL @@ -17,6 +17,7 @@ _CIMultiDictProxy = CIMultiDictProxy[str] _MultiDict = MultiDict[str] _MultiDictProxy = MultiDictProxy[str] + from http.cookies import BaseCookie # noqa else: _CIMultiDict = CIMultiDict _CIMultiDictProxy = CIMultiDictProxy @@ -29,6 +30,8 @@ LooseHeaders = Union[Mapping[str, str], _CIMultiDict, _CIMultiDictProxy] RawHeaders = Tuple[Tuple[bytes, bytes], ...] StrOrURL = Union[str, URL] +LooseCookies = Union[Iterable[Tuple[str, 'BaseCookie[str]']], + Mapping[str, 'BaseCookie[str]']] if sys.version_info >= (3, 6):