Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make web.BaseRequest and web.Request slot-based #3942

Merged
merged 5 commits into from
Jul 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/3942.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `web.BaseRequest`, `web.Request`, `web.StreamResponse`, `web.Response` and `web.WebSocketResponse` slot-based, prevent custom instance attributes.
1 change: 1 addition & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ def __init__(self, method: str, url: URL, *,
loop: asyncio.AbstractEventLoop,
session: 'ClientSession') -> None:
assert isinstance(url, URL)
super().__init__()

self.method = method
self.cookies = SimpleCookie()
Expand Down
10 changes: 5 additions & 5 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,12 +615,12 @@ def __enter__(self) -> async_timeout.timeout:

class HeadersMixin:

ATTRS = frozenset([
'_content_type', '_content_dict', '_stored_content_type'])
__slots__ = ('_content_type', '_content_dict', '_stored_content_type')

_content_type = None # type: Optional[str]
_content_dict = None # type: Optional[Dict[str, str]]
_stored_content_type = sentinel
def __init__(self) -> None:
self._content_type = None # type: Optional[str]
self._content_dict = None # type: Optional[Dict[str, str]]
self._stored_content_type = sentinel

def _parse_content_type(self, raw: str) -> None:
self._stored_content_type = raw
Expand Down
19 changes: 4 additions & 15 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import string
import tempfile
import types
import warnings
from email.utils import parsedate
from http.cookies import SimpleCookie
from types import MappingProxyType
Expand All @@ -31,7 +30,6 @@
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
DEBUG,
ChainMapProxy,
HeadersMixin,
is_expected_content_type,
Expand Down Expand Up @@ -108,11 +106,11 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
hdrs.METH_TRACE, hdrs.METH_DELETE}

ATTRS = HeadersMixin.ATTRS | frozenset([
__slots__ = (
'_message', '_protocol', '_payload_writer', '_payload', '_headers',
'_method', '_version', '_rel_url', '_post', '_read_bytes',
'_state', '_cache', '_task', '_client_max_size', '_loop',
'_transport_sslcontext', '_transport_peername'])
'_transport_sslcontext', '_transport_peername')

def __init__(self, message: RawRequestMessage,
payload: StreamReader, protocol: 'RequestHandler',
Expand All @@ -124,6 +122,7 @@ def __init__(self, message: RawRequestMessage,
scheme: Optional[str]=None,
host: Optional[str]=None,
remote: Optional[str]=None) -> None:
super().__init__()
if state is None:
state = {}
self._message = message
Expand Down Expand Up @@ -689,7 +688,7 @@ async def _prepare_hook(self, response: StreamResponse) -> None:

class Request(BaseRequest):

ATTRS = BaseRequest.ATTRS | frozenset(['_match_info'])
__slots__ = ('_match_info',)

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -700,16 +699,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# initialized after route resolving
self._match_info = None # type: Optional[UrlMappingMatchInfo]

if DEBUG:
def __setattr__(self, name: str, val: Any) -> None:
if name not in self.ATTRS:
warnings.warn("Setting custom {}.{} attribute "
"is discouraged".format(self.__class__.__name__,
name),
DeprecationWarning,
stacklevel=2)
super().__setattr__(name, val)

def clone(self, *, method: str=sentinel, rel_url:
StrOrURL=sentinel, headers: LooseHeaders=sentinel,
scheme: str=sentinel, host: str=sentinel, remote:
Expand Down
12 changes: 11 additions & 1 deletion aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,17 @@ class ContentCoding(enum.Enum):

class StreamResponse(BaseClass, HeadersMixin):

_length_check = True
__slots__ = ('_length_check', '_body', '_keep_alive', '_chunked',
'_compression', '_compression_force', '_cookies', '_req',
'_payload_writer', '_eof_sent', '_body_length', '_state',
'_headers', '_status', '_reason')

def __init__(self, *,
status: int=200,
reason: Optional[str]=None,
headers: Optional[LooseHeaders]=None) -> None:
super().__init__()
self._length_check = True
self._body = None
self._keep_alive = None # type: Optional[bool]
self._chunked = False
Expand Down Expand Up @@ -465,6 +470,11 @@ def __eq__(self, other: object) -> bool:

class Response(StreamResponse):

__slots__ = ('_body_payload',
'_compressed_body',
'_zlib_executor_size',
'_zlib_executor')

def __init__(self, *,
body: Any=None,
status: int=200,
Expand Down
8 changes: 6 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ def __bool__(self) -> bool:


class WebSocketResponse(StreamResponse):

_length_check = False
__slots__ = ('_protocols', '_ws_protocol', '_writer', '_reader', '_closed',
'_closing', '_conn_lost', '_close_code', '_loop', '_waiting',
'_exception', '_timeout', '_receive_timeout', '_autoclose',
'_autoping', '_heartbeat', '_heartbeat_cb', '_pong_heartbeat',
'_pong_response_cb', '_compress', '_max_msg_size')

def __init__(self, *,
timeout: float=10.0, receive_timeout: Optional[float]=None,
Expand All @@ -55,6 +58,7 @@ def __init__(self, *,
protocols: Iterable[str]=(),
compress: bool=True, max_msg_size: int=4*1024*1024) -> None:
super().__init__(status=101)
self._length_check = False
self._protocols = protocols
self._ws_protocol = None # type: Optional[str]
self._writer = None # type: Optional[WebSocketWriter]
Expand Down
8 changes: 5 additions & 3 deletions tests/test_web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ async def test_unhandled_runtime_error(
make_srv, transport, request_handler
):

class MyResponse(web.Response):
async def write_eof(self, data=b''):
raise RuntimeError()

async def handle(request):
resp = web.Response()
resp.write_eof = mock.Mock()
resp.write_eof.side_effect = RuntimeError
resp = MyResponse()
return resp

loop = asyncio.get_event_loop()
Expand Down
37 changes: 36 additions & 1 deletion tests/test_web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from unittest import mock

import pytest
from multidict import CIMultiDict, MultiDict
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict
from yarl import URL

from aiohttp import HttpVersion, web
from aiohttp.helpers import DEBUG
from aiohttp.http_parser import RawRequestMessage
from aiohttp.streams import StreamReader
from aiohttp.test_utils import make_mocked_request
from aiohttp.web import HTTPRequestEntityTooLarge, HTTPUnsupportedMediaType
Expand All @@ -19,6 +20,38 @@ def protocol():
return mock.Mock(_reading_paused=False)


def test_base_ctor() -> None:
message = RawRequestMessage(
'GET', '/path/to?a=1&b=2', HttpVersion(1, 1),
CIMultiDictProxy(CIMultiDict()), (),
False, False, False, False, URL('/path/to?a=1&b=2'))

req = web.BaseRequest(message,
mock.Mock(),
mock.Mock(),
mock.Mock(),
mock.Mock(),
mock.Mock())

assert 'GET' == req.method
assert HttpVersion(1, 1) == req.version
assert req.host == socket.getfqdn()
assert '/path/to?a=1&b=2' == req.path_qs
assert '/path/to' == req.path
assert 'a=1&b=2' == req.query_string
assert CIMultiDict() == req.headers
assert () == req.raw_headers

get = req.query
assert MultiDict([('a', '1'), ('b', '2')]) == get
# second call should return the same object
assert get is req.query

assert req.keep_alive

assert '__dict__' not in dir(req)


def test_ctor() -> None:
req = make_mocked_request('GET', '/path/to?a=1&b=2')

Expand Down Expand Up @@ -53,6 +86,8 @@ def test_ctor() -> None:
assert req.raw_headers == ((b'FOO', b'bar'),)
assert req.task is req._task

assert '__dict__' not in dir(req)


def test_doubleslashes() -> None:
# NB: //foo/bar is an absolute URL with foo netloc and /bar path
Expand Down
7 changes: 5 additions & 2 deletions tests/test_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ async def handler(request):
async def test_raw_server_cancelled_in_write_eof(aiohttp_raw_server,
aiohttp_client):

class MyResponse(web.Response):
async def write_eof(self, data=b''):
raise asyncio.CancelledError("error")

async def handler(request):
resp = web.Response(text=str(request.rel_url))
resp.write_eof = mock.Mock(side_effect=asyncio.CancelledError("error"))
resp = MyResponse(text=str(request.rel_url))
return resp

loop = asyncio.get_event_loop()
Expand Down
30 changes: 1 addition & 29 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from multidict import CIMultiDict

from aiohttp import WSMessage, WSMsgType, signals
from aiohttp import WSMsgType, signals
from aiohttp.log import ws_logger
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
Expand Down Expand Up @@ -105,34 +105,6 @@ async def test_nonstarted_receive_json() -> None:
await ws.receive_json()


async def test_receive_str_nonstring(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)

async def receive():
return WSMessage(WSMsgType.BINARY, b'data', b'')

ws.receive = receive

with pytest.raises(TypeError):
await ws.receive_str()


async def test_receive_bytes_nonsbytes(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)

async def receive():
return WSMessage(WSMsgType.TEXT, 'data', b'')

ws.receive = receive

with pytest.raises(TypeError):
await ws.receive_bytes()


async def test_send_str_nonstring(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
Expand Down
42 changes: 42 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,45 @@ async def handler(request):
ws = await client.ws_connect('/')
data = await ws.receive_str()
assert data == 'OK'


async def test_receive_str_nonstring(loop, aiohttp_client) -> None:

async def handler(request):
ws = web.WebSocketResponse()
if not ws.can_prepare(request):
return web.HTTPUpgradeRequired()

await ws.prepare(request)
await ws.send_bytes(b'answer')
await ws.close()
return ws

app = web.Application()
app.router.add_route('GET', '/', handler)
client = await aiohttp_client(app)

ws = await client.ws_connect('/')
with pytest.raises(TypeError):
await ws.receive_str()


async def test_receive_bytes_nonbytes(loop, aiohttp_client) -> None:

async def handler(request):
ws = web.WebSocketResponse()
if not ws.can_prepare(request):
return web.HTTPUpgradeRequired()

await ws.prepare(request)
await ws.send_bytes('answer')
await ws.close()
return ws

app = web.Application()
app.router.add_route('GET', '/', handler)
client = await aiohttp_client(app)

ws = await client.ws_connect('/')
with pytest.raises(TypeError):
await ws.receive_bytes()