Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a pair of memory object streams instead of two queues #2829

Merged
merged 1 commit into from
Dec 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 13 additions & 46 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
import math
import queue
import sys
import typing
from concurrent.futures import Future
Expand Down Expand Up @@ -85,14 +83,6 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
Expand All @@ -104,8 +94,6 @@ def __init__(
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
Expand All @@ -123,38 +111,23 @@ def __enter__(self) -> WebSocketTestSession:
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
break
if isinstance(message, BaseException):
raise message # pragma: no cover (defensive, should be impossible)
def __exit__(self, *args: typing.Any) -> bool | None:
return self.exit_stack.__exit__(*args)

async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
try:
with anyio.CancelScope() as cs:
task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
self._queue_event = anyio.Event()
await self._queue_event.wait()
return self._receive_queue.get()
send_tx, send_rx = anyio.create_memory_object_stream[Message](math.inf)
receive_tx, receive_rx = anyio.create_memory_object_stream[Message](math.inf)
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)

async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
Expand All @@ -172,9 +145,7 @@ def _raise_on_close(self, message: Message) -> None:
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

def send(self, message: Message) -> None:
self._receive_queue.put(message)
if hasattr(self, "_queue_event"):
self.portal.start_task_soon(self._queue_event.set)
self.portal.call(self._receive_tx.send, message)

def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
Expand All @@ -193,11 +164,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
return self.portal.call(self._send_rx.receive)

def receive_text(self) -> str:
message = self.receive()
Expand Down
Loading