Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WebSocket reader flow control calculations #9685

Merged
merged 41 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a8fa035
cleanups
bdraco Nov 5, 2024
539ff38
Update aiohttp/_websocket/reader_py.py
bdraco Nov 5, 2024
96d4113
Update aiohttp/client.py
bdraco Nov 5, 2024
ba192fd
Update aiohttp/client.py
bdraco Nov 5, 2024
891036c
Update aiohttp/client_ws.py
bdraco Nov 5, 2024
7cbe329
Update aiohttp/web_ws.py
bdraco Nov 5, 2024
f6b8faa
payload may be None
bdraco Nov 5, 2024
704010b
payload may not be len able
bdraco Nov 5, 2024
a1b54ed
cleanup
bdraco Nov 5, 2024
9555016
split it
bdraco Nov 6, 2024
666809f
split it
bdraco Nov 6, 2024
2e5dc54
preen
bdraco Nov 6, 2024
abbd219
fix
bdraco Nov 6, 2024
8cd2001
fixes
bdraco Nov 6, 2024
bec470b
more fixes
bdraco Nov 6, 2024
faad625
fix more tests
bdraco Nov 6, 2024
2b661c2
fix benchmark
bdraco Nov 6, 2024
25fd3fd
move size into message
bdraco Nov 6, 2024
a463ee1
move size into message
bdraco Nov 6, 2024
151bb30
move size into message
bdraco Nov 6, 2024
9726051
more fixes
bdraco Nov 6, 2024
2bac4c8
Apply suggestions from code review
bdraco Nov 6, 2024
7d56e4e
Apply suggestions from code review
bdraco Nov 6, 2024
1b59b7d
fix typo
bdraco Nov 6, 2024
763ec7e
fixes
bdraco Nov 6, 2024
328738d
preen
bdraco Nov 6, 2024
effafee
lint
bdraco Nov 6, 2024
b271554
cache prop
bdraco Nov 6, 2024
15ef47d
try to avoid all the py math
bdraco Nov 6, 2024
0cd95b0
no return
bdraco Nov 6, 2024
4cf1b69
preen
bdraco Nov 6, 2024
4773d89
dry
bdraco Nov 6, 2024
080d39a
compare data.size to tuple access
bdraco Nov 6, 2024
8cc4b8e
Apply suggestions from code review
bdraco Nov 6, 2024
b4f60eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
35a341c
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
bdraco Nov 6, 2024
af16201
Revert "Apply suggestions from code review"
bdraco Nov 6, 2024
7fb47a2
changelog
bdraco Nov 7, 2024
29dd891
Merge branch 'master' into websocket_flow_control
bdraco Nov 7, 2024
9e839d8
Update CHANGES/9685.breaking.rst
bdraco Nov 7, 2024
6c174e2
Merge branch 'master' into websocket_flow_control
bdraco Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,7 @@
payload_type,
)
from .resolver import AsyncResolver, DefaultResolver, ThreadedResolver
from .streams import (
EMPTY_PAYLOAD,
DataQueue,
EofStream,
FlowControlDataQueue,
StreamReader,
)
from .streams import EMPTY_PAYLOAD, DataQueue, EofStream, StreamReader
from .tracing import (
TraceConfig,
TraceConnectionCreateEndParams,
Expand Down Expand Up @@ -206,7 +200,6 @@
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"FlowControlDataQueue",
"StreamReader",
# tracing
"TraceConfig",
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/_websocket/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@
from ..helpers import NO_EXTENSIONS

if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import WebSocketReader as WebSocketReaderPython
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)

WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
else:
try:
from .reader_c import ( # type: ignore[import-not-found]
WebSocketDataQueue as WebSocketDataQueueCython,
WebSocketReader as WebSocketReaderCython,
)

WebSocketReader = WebSocketReaderCython
WebSocketDataQueue = WebSocketDataQueueCython
except ImportError: # pragma: no cover
from .reader_py import WebSocketReader as WebSocketReaderPython
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)

WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
44 changes: 42 additions & 2 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Reader for WebSocket protocol versions 13 and 8."""

import asyncio
from typing import Final, List, Optional, Set, Tuple, Union

from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import set_exception
from ..streams import DataQueue
from ..streams import DataQueue, EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
Expand Down Expand Up @@ -45,9 +47,47 @@
TUPLE_NEW = tuple.__new__


class WebSocketDataQueue(DataQueue[WSMessage]):
"""WebSocketDataQueue resumes and pauses an underlying stream.

It is a destination for WebSocket data.
"""

def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
super().__init__(loop=loop)
self._size = 0
self._protocol = protocol
self._limit = limit * 2

def feed_data(self, data: WSMessage) -> None:
self._size += len(data[0])
self._buffer.append(data)
if (waiter := self._waiter) is not None:
self._waiter = None
if not waiter.done():
waiter.set_result(None)
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()

async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
await self._wait_for_data()
if self._buffer:
data = self._buffer.popleft()
self._size -= len(data[0])
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
Fixed Show fixed Hide fixed
raise EofStream


class WebSocketReader:
def __init__(
self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._queue_feed_data = queue.feed_data
Expand Down
6 changes: 2 additions & 4 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from yarl import URL

from . import hdrs, http, payload
from ._websocket.reader import WebSocketDataQueue
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
Expand Down Expand Up @@ -102,7 +103,6 @@
)
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tracing import Trace, TraceConfig
from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL

Expand Down Expand Up @@ -1035,9 +1035,7 @@ async def _ws_connect(

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
conn_proto, 2**16, loop=self._loop
)
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
conn_proto,
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def set_exception(
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: FlowControlDataQueue
# payload: WebSocketDataQueue
# but they are not generi enough
# Need an ABC for both types
self._payload = payload
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import TracebackType
from typing import Any, Final, Optional, Type

from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
Expand All @@ -18,7 +19,7 @@
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError
from .streams import EofStream, FlowControlDataQueue
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
Expand Down Expand Up @@ -46,7 +47,7 @@ class ClientWSTimeout:
class ClientWebSocketResponse:
def __init__(
self,
reader: "FlowControlDataQueue[WSMessage]",
reader: WebSocketDataQueue,
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
Expand Down
49 changes: 5 additions & 44 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@
"EofStream",
"StreamReader",
"DataQueue",
"FlowControlDataQueue",
)

_T = TypeVar("_T")
_SizedT = TypeVar("_SizedT", bound=collections.abc.Sized)


class EofStream(Exception):
Expand Down Expand Up @@ -600,15 +598,15 @@ def read_nowait(self, n: int = -1) -> bytes:
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()


class DataQueue(Generic[_SizedT]):
class DataQueue(Generic[_T]):
"""DataQueue is a general-purpose blocking queue with one reader."""

def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[Type[BaseException], BaseException, None] = None
self._buffer: Deque[_SizedT] = collections.deque()
self._buffer: Deque[_T] = collections.deque()

def __len__(self) -> int:
return len(self._buffer)
Expand All @@ -633,7 +631,7 @@ def set_exception(
self._waiter = None
set_exception(waiter, exc, exc_cause)

def feed_data(self, data: _SizedT) -> None:
def feed_data(self, data: _T) -> None:
self._buffer.append(data)
if (waiter := self._waiter) is not None:
self._waiter = None
Expand All @@ -654,7 +652,7 @@ async def _wait_for_data(self) -> None:
self._waiter = None
raise

async def read(self) -> _SizedT:
async def read(self) -> _T:
if not self._buffer and not self._eof:
await self._wait_for_data()
if self._buffer:
Expand All @@ -663,42 +661,5 @@ async def read(self) -> _SizedT:
raise self._exception
raise EofStream

def __aiter__(self) -> AsyncStreamIterator[_SizedT]:
def __aiter__(self) -> AsyncStreamIterator[_T]:
return AsyncStreamIterator(self.read)


class FlowControlDataQueue(DataQueue[_SizedT]):
"""FlowControlDataQueue resumes and pauses an underlying stream.

It is a destination for parsed data.
"""

def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
super().__init__(loop=loop)
self._size = 0
self._protocol = protocol
self._limit = limit * 2

def feed_data(self, data: _SizedT) -> None:
self._size += len(data)
self._buffer.append(data)
if (waiter := self._waiter) is not None:
self._waiter = None
set_result(waiter, None)
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()

async def read(self) -> _SizedT:
if not self._buffer and not self._eof:
await self._wait_for_data()
if self._buffer:
data = self._buffer.popleft()
self._size -= len(data)
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
7 changes: 4 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from multidict import CIMultiDict

from . import hdrs
from ._websocket.reader import WebSocketDataQueue
from ._websocket.writer import DEFAULT_LIMIT
from .abc import AbstractStreamWriter
from .client_exceptions import WSMessageTypeError
Expand All @@ -29,7 +30,7 @@
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WSMessageError
from .log import ws_logger
from .streams import EofStream, FlowControlDataQueue
from .streams import EofStream
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_request import BaseRequest
Expand Down Expand Up @@ -105,7 +106,7 @@ def __init__(
self._protocols = protocols
self._ws_protocol: Optional[str] = None
self._writer: Optional[WebSocketWriter] = None
self._reader: Optional[FlowControlDataQueue[WSMessage]] = None
self._reader: Optional[WebSocketDataQueue[WSMessage]] = None
bdraco marked this conversation as resolved.
Show resolved Hide resolved
self._closed = False
self._closing = False
self._conn_lost = 0
Expand Down Expand Up @@ -355,7 +356,7 @@ def _post_start(

loop = self._loop
assert loop is not None
self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop)
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
request.protocol.set_parser(
WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
)
Expand Down
24 changes: 0 additions & 24 deletions tests/test_flowcontrol_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ def stream(
return streams.StreamReader(protocol, limit=1, loop=loop)


@pytest.fixture
def buffer(
loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
) -> streams.FlowControlDataQueue[str]:
return streams.FlowControlDataQueue[str](protocol, limit=1, loop=loop)


class TestFlowControlStreamReader:
async def test_read(self, stream: streams.StreamReader) -> None:
stream.feed_data(b"da")
Expand Down Expand Up @@ -114,20 +107,3 @@ async def test_read_nowait(self, stream: streams.StreamReader) -> None:
res = stream.read_nowait(5)
assert res == b""
assert stream._protocol.resume_reading.call_count == 1 # type: ignore[attr-defined]


class TestFlowControlDataQueue:
def test_feed_pause(self, buffer: streams.FlowControlDataQueue[str]) -> None:
buffer._protocol._reading_paused = False
buffer.feed_data("x" * 100)

assert buffer._protocol.pause_reading.called # type: ignore[attr-defined]

async def test_resume_on_read(
self, buffer: streams.FlowControlDataQueue[str]
) -> None:
buffer.feed_data("x" * 100)

buffer._protocol._reading_paused = True
await buffer.read()
assert buffer._protocol.resume_reading.called # type: ignore[attr-defined]
36 changes: 36 additions & 0 deletions tests/test_websocket_data_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio
from unittest import mock

import pytest

from aiohttp._websocket.models import WSMessageBinary
from aiohttp._websocket.reader import WebSocketDataQueue
from aiohttp.base_protocol import BaseProtocol


@pytest.fixture
def protocol() -> BaseProtocol:
return mock.create_autospec(BaseProtocol, spec_set=True, instance=True, _reading_paused=False) # type: ignore[no-any-return]


@pytest.fixture
def buffer(
loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
) -> WebSocketDataQueue:
return WebSocketDataQueue(protocol, limit=1, loop=loop)


class TestWebSocketDataQueue:
def test_feed_pause(self, buffer: WebSocketDataQueue) -> None:
buffer._protocol._reading_paused = False
for _ in range(3):
buffer.feed_data(WSMessageBinary(b"x"))

assert buffer._protocol.pause_reading.called # type: ignore[attr-defined]

async def test_resume_on_read(self, buffer: WebSocketDataQueue) -> None:
buffer.feed_data(WSMessageBinary(b"x"))

buffer._protocol._reading_paused = True
await buffer.read()
assert buffer._protocol.resume_reading.called # type: ignore[attr-defined]
Loading