Skip to content

Commit

Permalink
use a pair of memory object streams instead of two queues (#2829)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert authored Dec 29, 2024
1 parent 27b6f4c commit e16bacb
Showing 1 changed file with 13 additions and 46 deletions.
59 changes: 13 additions & 46 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
import math
import queue
import sys
import typing
from concurrent.futures import Future
Expand Down Expand Up @@ -85,14 +83,6 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
Expand All @@ -104,8 +94,6 @@ def __init__(
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
Expand All @@ -123,38 +111,23 @@ def __enter__(self) -> WebSocketTestSession:
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
break
if isinstance(message, BaseException):
raise message # pragma: no cover (defensive, should be impossible)
def __exit__(self, *args: typing.Any) -> bool | None:
return self.exit_stack.__exit__(*args)

async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
try:
with anyio.CancelScope() as cs:
task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
self._queue_event = anyio.Event()
await self._queue_event.wait()
return self._receive_queue.get()
send_tx, send_rx = anyio.create_memory_object_stream[Message](math.inf)
receive_tx, receive_rx = anyio.create_memory_object_stream[Message](math.inf)
with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
self._receive_tx = receive_tx
self._send_rx = send_rx
task_status.started(cs)
await self.app(self.scope, receive_rx.receive, send_tx.send)

async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
# wait for cs.cancel to be called before closing streams
await anyio.sleep_forever()

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
Expand All @@ -172,9 +145,7 @@ def _raise_on_close(self, message: Message) -> None:
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

def send(self, message: Message) -> None:
self._receive_queue.put(message)
if hasattr(self, "_queue_event"):
self.portal.start_task_soon(self._queue_event.set)
self.portal.call(self._receive_tx.send, message)

def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
Expand All @@ -193,11 +164,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
return self.portal.call(self._send_rx.receive)

def receive_text(self) -> str:
message = self.receive()
Expand Down

0 comments on commit e16bacb

Please sign in to comment.