Skip to content

Commit

Permalink
collect errors more reliably from websocket test client (#2814)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert authored Dec 29, 2024
1 parent 31d182c commit 27b6f4c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 49 deletions.
79 changes: 36 additions & 43 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
Expand All @@ -9,7 +10,6 @@
import sys
import typing
from concurrent.futures import Future
from functools import cached_property
from types import GeneratorType
from urllib.parse import unquote, urljoin

Expand Down Expand Up @@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
Expand All @@ -97,63 +105,47 @@ def __init__(
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())

try:
_: Future[None] = self.portal.start_task_soon(self._run)
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
fut, cs = portal.start_task(self._run)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self

@cached_property
def should_close(self) -> anyio.Event:
return anyio.Event()

async def _notify_close(self) -> None:
self.should_close.set()
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
break
if isinstance(message, BaseException):
raise message
raise message # pragma: no cover (defensive, should be impossible)

async def _run(self) -> None:
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""

async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
try:
with anyio.CancelScope() as cs:
task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
Expand Down Expand Up @@ -202,6 +194,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
Expand Down
18 changes: 12 additions & 6 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,25 @@ async def asgi(receive: Receive, send: Send) -> None:


def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
cancelled = False

def app(scope: Scope) -> ASGIInstance:
async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
while True:
await anyio.sleep(0.1)
nonlocal cancelled
try:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await anyio.sleep_forever()
except anyio.get_cancelled_exc_class():
cancelled = True
raise

return asgi

client = test_client_factory(app) # type: ignore
with client.websocket_connect("/") as websocket:
with client.websocket_connect("/"):
...
assert websocket.should_close.is_set()
assert cancelled


def test_client(test_client_factory: TestClientFactory) -> None:
Expand Down

0 comments on commit 27b6f4c

Please sign in to comment.