Skip to content

Commit

Permalink
[PR #8498/7bf6ee1 backport][3.10] Avoid creating a future on every we…
Browse files Browse the repository at this point in the history
…bsocket receive (#8503)

Co-authored-by: Sam Bull <[email protected]>
  • Loading branch information
bdraco and Dreamsorcerer authored Jul 14, 2024
1 parent aae7ac5 commit 3bc89fe
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES/8498.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid creating a future on every websocket receive -- by :user:`bdraco`.
19 changes: 11 additions & 8 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
Expand Down Expand Up @@ -185,10 +186,12 @@ async def send_json(
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closing:
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._closing = True
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
await self._close_wait

if not self._closed:
self._cancel_heartbeat()
Expand Down Expand Up @@ -232,7 +235,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")

if self._closed:
Expand All @@ -242,15 +245,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WS_CLOSED_MESSAGE

try:
self._waiting = self._loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(timeout or self._receive_timeout):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
self._waiting = None
set_result(waiter, True)
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
Expand Down
20 changes: 12 additions & 8 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def __init__(
self._conn_lost = 0
self._close_code: Optional[int] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._waiting: Optional[asyncio.Future[bool]] = None
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timeout = timeout
self._receive_timeout = receive_timeout
Expand Down Expand Up @@ -376,9 +377,12 @@ async def close(

# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
if self._waiting and not self._closed:
if not self._close_wait:
assert self._loop is not None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
await self._close_wait

if self._closed:
return False
Expand Down Expand Up @@ -445,7 +449,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
loop = self._loop
assert loop is not None
while True:
if self._waiting is not None:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")

if self._closed:
Expand All @@ -457,15 +461,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WS_CLOSING_MESSAGE

try:
self._waiting = loop.create_future()
self._waiting = True
try:
async with async_timeout.timeout(timeout or self._receive_timeout):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
set_result(waiter, True)
self._waiting = None
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except asyncio.TimeoutError:
raise
except EofStream:
Expand Down
39 changes: 36 additions & 3 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import sys
from typing import Any

import pytest

Expand Down Expand Up @@ -245,7 +246,7 @@ async def handler(request):
await client_ws.close()

msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE
assert msg.type is aiohttp.WSMsgType.CLOSE
return ws

app = web.Application()
Expand All @@ -256,11 +257,43 @@ async def handler(request):
await ws.send_bytes(b"ask")

msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSING
assert msg.type is aiohttp.WSMsgType.CLOSING

await asyncio.sleep(0.01)
msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSED
assert msg.type is aiohttp.WSMsgType.CLOSED


async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None:
async def handler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)

await ws.receive_bytes()
await ws.send_str("test")

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSE
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
ws = await client.ws_connect("/")

await ws.send_bytes(b"ask")

task1 = asyncio.create_task(ws.close())
task2 = asyncio.create_task(ws.close())

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED

await task1
await task2

msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED


async def test_concurrent_task_close(aiohttp_client) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,47 @@ async def handler(request):
assert msg.type == WSMsgType.CLOSED


async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None:
srv_ws = None

async def handler(request):
nonlocal srv_ws
ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
await ws.prepare(request)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING

await asyncio.sleep(0)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED

return ws

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))
task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE

await task1
await task2

await asyncio.sleep(0)
msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED


async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None:
srv_ws: Optional[web.WebSocketResponse] = None

Expand Down

0 comments on commit 3bc89fe

Please sign in to comment.