diff --git a/CHANGES/3045.feature b/CHANGES/3045.feature new file mode 100644 index 00000000000..cbd7ab4f583 --- /dev/null +++ b/CHANGES/3045.feature @@ -0,0 +1 @@ +Limit websocket message size on reading to 4 MB by default. \ No newline at end of file diff --git a/aiohttp/client.py b/aiohttp/client.py index d70b5f7a966..e8d32d8dce8 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -519,7 +519,8 @@ def ws_connect(self, url, *, fingerprint=None, ssl_context=None, proxy_headers=None, - compress=0): + compress=0, + max_msg_size=4*1024*1024): """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect(url, @@ -539,7 +540,8 @@ def ws_connect(self, url, *, fingerprint=fingerprint, ssl_context=ssl_context, proxy_headers=proxy_headers, - compress=compress)) + compress=compress, + max_msg_size=max_msg_size)) async def _ws_connect(self, url, *, protocols=(), @@ -558,7 +560,8 @@ async def _ws_connect(self, url, *, fingerprint=None, ssl_context=None, proxy_headers=None, - compress=0): + compress=0, + max_msg_size=4*1024*1024): if headers is None: headers = CIMultiDict() @@ -667,7 +670,7 @@ async def _ws_connect(self, url, *, transport = resp.connection.transport reader = FlowControlDataQueue( proto, limit=2 ** 16, loop=self._loop) - proto.set_parser(WebSocketReader(reader), reader) + proto.set_parser(WebSocketReader(reader, max_msg_size), reader) tcp_nodelay(transport, True) writer = WebSocketWriter( proto, transport, use_mask=True, diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index edc1b8dffb3..4c8983c744a 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -236,8 +236,9 @@ class WSParserState(IntEnum): class WebSocketReader: - def __init__(self, queue, compress=True): + def __init__(self, queue, max_msg_size, compress=True): self.queue = queue + self._max_msg_size = max_msg_size self._exc = None self._partial = bytearray() @@ -320,6 +321,12 @@ def _feed_data(self, data): if opcode != WSMsgType.CONTINUATION: self._opcode = opcode self._partial.extend(payload) + if (self._max_msg_size and + len(self._partial) >= self._max_msg_size): + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size)) else: # previous frame was non finished # we should get continuation opcode @@ -335,13 +342,26 @@ def _feed_data(self, data): self._opcode = None self._partial.extend(payload) + if (self._max_msg_size and + len(self._partial) >= self._max_msg_size): + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size)) # Decompress process must to be done after all packets # received. if compressed: self._partial.extend(_WS_DEFLATE_TRAILING) payload_merged = self._decompressobj.decompress( - self._partial) + self._partial, self._max_msg_size) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size exceeds limit {}". + format(self._max_msg_size + left, + self._max_msg_size)) else: payload_merged = bytes(self._partial) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index be2bea699e4..13273950754 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -38,7 +38,7 @@ class WebSocketResponse(StreamResponse): def __init__(self, *, timeout=10.0, receive_timeout=None, autoclose=True, autoping=True, heartbeat=None, - protocols=(), compress=True): + protocols=(), compress=True, max_msg_size=4*1024*1024): super().__init__(status=101) self._protocols = protocols self._ws_protocol = None @@ -61,6 +61,7 @@ def __init__(self, *, self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb = None self._compress = compress + self._max_msg_size = max_msg_size def _cancel_heartbeat(self): if self._pong_response_cb is not None: @@ -203,7 +204,7 @@ def _post_start(self, request, protocol, writer): self._reader = FlowControlDataQueue( request._protocol, limit=2 ** 16, loop=self._loop) request.protocol.set_parser(WebSocketReader( - self._reader, compress=self._compress)) + self._reader, self._max_msg_size, compress=self._compress)) # disable HTTP keepalive for WebSocket request.protocol.keep_alive(False) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 07d2ff990d2..a29485326d2 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -498,7 +498,7 @@ The client session supports the context manager protocol for self closing. proxy=None, proxy_auth=None, ssl=None, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ - compress=0) + compress=0, max_msg_size=4194304) :async-with: :coroutine: @@ -601,6 +601,12 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 2.3 + :param int max_msg_size: maximum size of read websocket message, + 4 MB by default. To disable the size + limit use ``0``. + + .. versionadded:: 3.3 + .. comethod:: close() diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 5f9aa871dd4..52d459a74de 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -865,7 +865,7 @@ WebSocketResponse .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ autoclose=True, autoping=True, heartbeat=None, \ - protocols=(), compress=True) + protocols=(), compress=True, max_msg_size=4194304) Class for handling server-side websockets, inherited from :class:`StreamResponse`. @@ -903,6 +903,12 @@ WebSocketResponse :param bool compress: Enable per-message deflate extension support. False for disabled, default value is True. + :param int max_msg_size: maximum size of read websocket message, 4 + MB by default. To disable the size limit use ``0``. + + .. versionadded:: 3.3 + + The class supports ``async for`` statement for iterating over incoming messages:: diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 5ef9bd6ad76..e16b42dcf6a 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -1,5 +1,6 @@ import random import struct +import zlib from unittest import mock import pytest @@ -7,13 +8,20 @@ import aiohttp from aiohttp import http_websocket from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType -from aiohttp.http_websocket import (PACK_CLOSE_CODE, PACK_LEN1, PACK_LEN2, - PACK_LEN3, WebSocketReader, - _websocket_mask) +from aiohttp.http_websocket import (_WS_DEFLATE_TRAILING, PACK_CLOSE_CODE, + PACK_LEN1, PACK_LEN2, PACK_LEN3, + WebSocketReader, _websocket_mask) -def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True): +def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True, + compress=False): """Send a frame over the websocket with message as its payload.""" + if compress: + compressobj = zlib.compressobj(wbits=-9) + message = compressobj.compress(message) + message = message + compressobj.flush(zlib.Z_SYNC_FLUSH) + if message.endswith(_WS_DEFLATE_TRAILING): + message = message[:-4] msg_length = len(message) if use_mask: # pragma: no cover mask_bit = 0x80 @@ -25,6 +33,9 @@ def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True): else: header_first_byte = opcode + if compress: + header_first_byte |= 0x40 + if msg_length < 126: header = PACK_LEN1( header_first_byte, msg_length | mask_bit) @@ -67,7 +78,7 @@ def out(loop): @pytest.fixture() def parser(out): - return WebSocketReader(out) + return WebSocketReader(out, 4*1024*1024) def test_parse_frame(parser): @@ -444,16 +455,35 @@ def test_parse_compress_error_frame(parser): assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR -@pytest.fixture() -def parser_no_compress(out): - return WebSocketReader(out, compress=False) - - -def test_parse_no_compress_frame_single(parser_no_compress): - +def test_parse_no_compress_frame_single(): + parser_no_compress = WebSocketReader(out, 0, compress=False) with pytest.raises(WebSocketError) as ctx: parser_no_compress.parse_frame(struct.pack( '!BB', 0b11000001, 0b00000001)) parser_no_compress.parse_frame(b'1') assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + + +def test_msg_too_large(out): + parser = WebSocketReader(out, 256, compress=False) + data = build_frame(b'text'*256, WSMsgType.TEXT) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG + + +def test_msg_too_large_not_fin(out): + parser = WebSocketReader(out, 256, compress=False) + data = build_frame(b'text'*256, WSMsgType.TEXT, is_fin=False) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG + + +def test_compressed_msg_too_large(out): + parser = WebSocketReader(out, 256, compress=True) + data = build_frame(b'aaa'*256, WSMsgType.TEXT, compress=True) + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(data) + assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG