diff --git a/starlette/testclient.py b/starlette/testclient.py index a14f646d4..9a0abbd7b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,12 +1,10 @@ from __future__ import annotations import contextlib -import enum import inspect import io import json import math -import queue import sys import typing from concurrent.futures import Future @@ -85,14 +83,6 @@ class WebSocketDenialResponse( # type: ignore[misc] """ -class _Eof(enum.Enum): - EOF = enum.auto() - - -EOF: typing.Final = _Eof.EOF -Eof = typing.Literal[_Eof.EOF] - - class WebSocketTestSession: def __init__( self, @@ -104,8 +94,6 @@ def __init__( self.scope = scope self.accepted_subprotocol = None self.portal_factory = portal_factory - self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None def __enter__(self) -> WebSocketTestSession: @@ -123,38 +111,23 @@ def __enter__(self) -> WebSocketTestSession: self.exit_stack = stack.pop_all() return self - def __exit__(self, *args: typing.Any) -> None: - self.exit_stack.close() - - while True: - message = self._send_queue.get() - if message is EOF: - break - if isinstance(message, BaseException): - raise message # pragma: no cover (defensive, should be impossible) + def __exit__(self, *args: typing.Any) -> bool | None: + return self.exit_stack.__exit__(*args) async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: """ The sub-thread in which the websocket session runs. """ - try: - with anyio.CancelScope() as cs: - task_status.started(cs) - await self.app(self.scope, self._asgi_receive, self._asgi_send) - except BaseException as exc: - self._send_queue.put(exc) - raise - finally: - self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ - - async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - self._queue_event = anyio.Event() - await self._queue_event.wait() - return self._receive_queue.get() + send_tx, send_rx = anyio.create_memory_object_stream[Message](math.inf) + receive_tx, receive_rx = anyio.create_memory_object_stream[Message](math.inf) + with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: + self._receive_tx = receive_tx + self._send_rx = send_rx + task_status.started(cs) + await self.app(self.scope, receive_rx.receive, send_tx.send) - async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) + # wait for cs.cancel to be called before closing streams + await anyio.sleep_forever() def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": @@ -172,9 +145,7 @@ def _raise_on_close(self, message: Message) -> None: raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) def send(self, message: Message) -> None: - self._receive_queue.put(message) - if hasattr(self, "_queue_event"): - self.portal.start_task_soon(self._queue_event.set) + self.portal.call(self._receive_tx.send, message) def send_text(self, data: str) -> None: self.send({"type": "websocket.receive", "text": data}) @@ -193,11 +164,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) def receive(self) -> Message: - message = self._send_queue.get() - assert message is not EOF - if isinstance(message, BaseException): - raise message - return message + return self.portal.call(self._send_rx.receive) def receive_text(self) -> str: message = self.receive()