From 1fabfad271fe1a51165fe99cc4b828b1241af34c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 22 Nov 2020 14:01:53 +0200 Subject: [PATCH] [3.8] Make type hints for http parser stricter (#5267). (#5268) (cherry picked from commit a6c7f154ddee11e6e23c66c830b5b0b668f81c8e) Co-authored-by: Andrew Svetlov --- CHANGES/5267.feature | 1 + aiohttp/client_exceptions.py | 3 ++- aiohttp/client_proto.py | 4 ++-- aiohttp/http_parser.py | 28 ++++++++++++++++------------ aiohttp/web_protocol.py | 14 ++++++++++++-- 5 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 CHANGES/5267.feature diff --git a/CHANGES/5267.feature b/CHANGES/5267.feature new file mode 100644 index 00000000000..63dd2ffc518 --- /dev/null +++ b/CHANGES/5267.feature @@ -0,0 +1 @@ +Make type hints for http parser stricter diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 7bc483ce681..4c96d556793 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -4,6 +4,7 @@ import warnings from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from .http_parser import RawResponseMessage from .typedefs import LooseHeaders try: @@ -225,7 +226,7 @@ class ServerConnectionError(ClientConnectionError): class ServerDisconnectedError(ServerConnectionError): """Server disconnected.""" - def __init__(self, message: Optional[str] = None) -> None: + def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None: if message is None: message = "Server disconnected" diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 2973342e440..7ed6c878155 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -23,7 +23,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._should_close = False - self._payload = None + self._payload: Optional[StreamReader] = None self._skip_payload = False self._payload_parser = None @@ -223,7 +223,7 @@ def data_received(self, data: bytes) -> None: self._upgraded = upgraded - payload = None + payload: Optional[StreamReader] = None for message, payload in messages: if message.should_close: self._should_close = True diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 90bd05a25c3..940371c588c 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -4,8 +4,9 @@ import re import string import zlib +from contextlib import suppress from enum import IntEnum -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL @@ -88,6 +89,9 @@ ) +_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) + + class ParseState(IntEnum): PARSE_NONE = 0 @@ -198,7 +202,7 @@ def parse_headers( return (CIMultiDictProxy(headers), tuple(raw_headers)) -class HttpParser(abc.ABC): +class HttpParser(abc.ABC, Generic[_MsgT]): def __init__( self, protocol: Optional[BaseProtocol] = None, @@ -239,10 +243,10 @@ def __init__( self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size) @abc.abstractmethod - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> _MsgT: pass - def feed_eof(self) -> Any: + def feed_eof(self) -> Optional[_MsgT]: if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None @@ -254,10 +258,9 @@ def feed_eof(self) -> Any: if self._lines: if self._lines[-1] != "\r\n": self._lines.append(b"") - try: + with suppress(Exception): return self.parse_message(self._lines) - except Exception: - return None + return None def feed_data( self, @@ -267,7 +270,7 @@ def feed_data( CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, METH_CONNECT: str = hdrs.METH_CONNECT, SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, - ) -> Tuple[List[Any], bool, bytes]: + ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: messages = [] @@ -346,6 +349,7 @@ def feed_data( if not payload_parser.done: self._payload_parser = payload_parser elif method == METH_CONNECT: + assert isinstance(msg, RawRequestMessage) payload = StreamReader( self.protocol, timer=self.timer, @@ -479,13 +483,13 @@ def set_upgraded(self, val: bool) -> None: self._upgraded = val -class HttpRequestParser(HttpParser): +class HttpRequestParser(HttpParser[RawRequestMessage]): """Read request status line. Exception .http_exceptions.BadStatusLine could be raised in case of any errors in status line. Returns RawRequestMessage. """ - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> RawRequestMessage: # request line line = lines[0].decode("utf-8", "surrogateescape") try: @@ -542,13 +546,13 @@ def parse_message(self, lines: List[bytes]) -> Any: ) -class HttpResponseParser(HttpParser): +class HttpResponseParser(HttpParser[RawResponseMessage]): """Read response status line and headers. BadStatusLine could be raised in case of any errors in status line. Returns RawResponseMessage""" - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> RawResponseMessage: line = lines[0].decode("utf-8", "surrogateescape") try: version, status = line.split(None, 1) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 8e02bc4aab7..5a032777dca 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -7,7 +7,17 @@ from html import escape as html_escape from http import HTTPStatus from logging import Logger -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Deque, + Optional, + Tuple, + Type, + cast, +) import yarl @@ -172,7 +182,7 @@ def __init__( self._keepalive_timeout = keepalive_timeout self._lingering_time = float(lingering_time) - self._messages = deque() # type: Any # Python 3.5 has no typing.Deque + self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque() self._message_tail = b"" self._waiter = None # type: Optional[asyncio.Future[None]]