diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8c2894a1a5 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This release improves the `graphql-transport-ws` implementation by starting the sub-protocol timeout only when the connection handshake is completed. diff --git a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py b/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py index 4b1c1a5151..46faa1a7ba 100644 --- a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +++ b/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py @@ -44,6 +44,7 @@ async def close(self, code: int, reason: str) -> None: async def handle_request(self) -> web.StreamResponse: await self._ws.prepare(self._request) + self.on_request_accepted() try: async for ws_message in self._ws: # type: http.WSMessage @@ -53,8 +54,6 @@ async def handle_request(self) -> web.StreamResponse: error_message = "WebSocket message type must be text" await self.handle_invalid_message(error_message) finally: - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - await self.reap_completed_tasks() + await self.shutdown() return self._ws diff --git a/strawberry/asgi/handlers/graphql_transport_ws_handler.py b/strawberry/asgi/handlers/graphql_transport_ws_handler.py index 6a5e4f6efc..c4f64fb323 100644 --- a/strawberry/asgi/handlers/graphql_transport_ws_handler.py +++ b/strawberry/asgi/handlers/graphql_transport_ws_handler.py @@ -46,6 +46,7 @@ async def close(self, code: int, reason: str) -> None: async def handle_request(self) -> None: await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) + self.on_request_accepted() try: while self._ws.application_state != WebSocketState.DISCONNECTED: @@ -59,6 +60,4 @@ async def handle_request(self) -> None: except WebSocketDisconnect: # pragma: no cover pass finally: - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - await self.reap_completed_tasks() + await self.shutdown() diff --git a/strawberry/channels/handlers/graphql_transport_ws_handler.py b/strawberry/channels/handlers/graphql_transport_ws_handler.py index 34676a06c8..10d7f8455c 100644 --- a/strawberry/channels/handlers/graphql_transport_ws_handler.py +++ b/strawberry/channels/handlers/graphql_transport_ws_handler.py @@ -53,9 +53,7 @@ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: async def handle_request(self) -> Any: await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) + self.on_request_accepted() async def handle_disconnect(self, code) -> None: - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - await self.reap_completed_tasks() + await self.shutdown() diff --git a/strawberry/starlite/handlers/graphql_transport_ws_handler.py b/strawberry/starlite/handlers/graphql_transport_ws_handler.py index 4396170efe..0ce30cbd36 100644 --- a/strawberry/starlite/handlers/graphql_transport_ws_handler.py +++ b/strawberry/starlite/handlers/graphql_transport_ws_handler.py @@ -39,6 +39,7 @@ async def close(self, code: int, reason: str) -> None: async def handle_request(self) -> None: await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL) + self.on_request_accepted() try: while self._ws.connection_state != "disconnect": @@ -52,6 +53,4 @@ async def handle_request(self) -> None: except WebSocketDisconnect: # pragma: no cover pass finally: - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - await self.reap_completed_tasks() + await self.shutdown() diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index ffa48ccddf..85a58b7578 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from abc import ABC, abstractmethod from contextlib import suppress from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional @@ -35,6 +36,8 @@ class BaseGraphQLTransportWSHandler(ABC): + task_logger: logging.Logger = logging.getLogger("strawberry.ws.task") + def __init__( self, schema: BaseSchema, @@ -47,6 +50,7 @@ def __init__( self.connection_init_timeout_task: Optional[asyncio.Task] = None self.connection_init_received = False self.connection_acknowledged = False + self.connection_timed_out = False self.subscriptions: Dict[str, AsyncGenerator] = {} self.tasks: Dict[str, asyncio.Task] = {} self.completed_tasks: List[asyncio.Task] = [] @@ -73,19 +77,50 @@ async def handle_request(self) -> Any: """Handle the request this instance was created for""" async def handle(self) -> Any: - timeout_handler = self.handle_connection_init_timeout() - self.connection_init_timeout_task = asyncio.create_task(timeout_handler) return await self.handle_request() + async def shutdown(self) -> None: + if self.connection_init_timeout_task: + self.connection_init_timeout_task.cancel() + with suppress(asyncio.CancelledError): + await self.connection_init_timeout_task + + for operation_id in list(self.subscriptions.keys()): + await self.cleanup_operation(operation_id) + await self.reap_completed_tasks() + + def on_request_accepted(self) -> None: + # handle_request should call this once it has sent the + # websocket.accept() response to start the timeout. + assert not self.connection_init_timeout_task + self.connection_init_timeout_task = asyncio.create_task( + self.handle_connection_init_timeout() + ) + async def handle_connection_init_timeout(self) -> None: - delay = self.connection_init_wait_timeout.total_seconds() - await asyncio.sleep(delay=delay) + task = asyncio.current_task() + assert task + try: + delay = self.connection_init_wait_timeout.total_seconds() + await asyncio.sleep(delay=delay) - if self.connection_init_received: - return + if self.connection_init_received: + return # pragma: no cover + + self.connection_timed_out = True + reason = "Connection initialisation timeout" + await self.close(code=4408, reason=reason) + except asyncio.CancelledError: + raise + except Exception as error: + await self.handle_task_exception(error) # pragma: no cover + finally: + # do not clear self.connection_init_timeout_task + # so that unittests can inspect it. + self.completed_tasks.append(task) - reason = "Connection initialisation timeout" - await self.close(code=4408, reason=reason) + async def handle_task_exception(self, error: Exception) -> None: + self.task_logger.exception("Exception in worker task", exc_info=error) async def handle_message(self, message: dict) -> None: handler: Callable @@ -126,6 +161,12 @@ async def handle_message(self, message: dict) -> None: await self.reap_completed_tasks() async def handle_connection_init(self, message: ConnectionInitMessage) -> None: + if self.connection_timed_out: + # No way to reliably excercise this case during testing + return # pragma: no cover + if self.connection_init_timeout_task: + self.connection_init_timeout_task.cancel() + if message.payload is not UNSET and not isinstance(message.payload, dict): await self.close(code=4400, reason="Invalid connection init payload") return @@ -228,6 +269,7 @@ async def operation_task( Operation task top level method. Cleans up and de-registers the operation once it is done. """ + # TODO: Handle errors in this method using self.handle_task_exception() try: await self.handle_async_results(result_source, operation) except BaseException: # pragma: no cover diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 7fb86b509d..ce00dcce26 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -243,6 +243,17 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: class DebuggableGraphQLTransportWSMixin: + @staticmethod + def on_init(self): + """ + This method can be patched by unittests to get the instance of the + transport handler when it is initialized + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + DebuggableGraphQLTransportWSMixin.on_init(self) + async def get_context(self) -> object: context = await super().get_context() context["ws"] = self._ws diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 31533a7c22..3fddcdf30e 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -3,6 +3,12 @@ import time from datetime import timedelta from typing import AsyncGenerator, Type +from unittest.mock import patch + +try: + from unittest.mock import AsyncMock +except ImportError: + AsyncMock = None import pytest import pytest_asyncio @@ -19,7 +25,8 @@ SubscribeMessage, SubscribeMessagePayload, ) -from tests.http.clients import AioHttpClient +from tests.http.clients import AioHttpClient, ChannelsHttpClient +from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from ..http.clients import HttpClient, WebSocketClient @@ -120,40 +127,33 @@ async def test_connection_init_timeout(request, http_client_class: Type[HttpClie @pytest.mark.flaky async def test_connection_init_timeout_cancellation( - http_client_class: Type[HttpClient], + ws_raw: WebSocketClient, ): - test_client = http_client_class() - test_client.create_app(connection_init_wait_timeout=timedelta(milliseconds=100)) - async with test_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - await ws.send_json(ConnectionInitMessage().as_dict()) - - response = await ws.receive_json() - assert response == ConnectionAckMessage().as_dict() - - await asyncio.sleep(0.2) + # Verify that the timeout task is cancelled after the connection Init + # message is received + ws = ws_raw + await ws.send_json(ConnectionInitMessage().as_dict()) - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { debug { isConnectionInitTimeoutTaskDone } }" - ), - ).as_dict() - ) + response = await ws.receive_json() + assert response == ConnectionAckMessage().as_dict() - response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", - payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, - ).as_dict() - ) + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { debug { isConnectionInitTimeoutTaskDone } }" + ), + ).as_dict() + ) - await ws.close() - assert ws.closed + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub1", + payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, + ).as_dict() + ) async def test_too_many_initialisation_requests(ws: WebSocketClient): @@ -790,3 +790,44 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): end = time.time() elapsed = end - start assert elapsed < delay + + +async def test_error_handler_for_timeout(http_client: HttpClient): + """ + Test that the error handler is called when the timeout + task encounters an error + """ + if isinstance(http_client, ChannelsHttpClient): + pytest.skip("Can't patch on_init for this client") + if not AsyncMock: + pytest.skip("Don't have AsyncMock") + ws = ws_raw + handler = None + errorhandler = AsyncMock() + + def on_init(_handler): + nonlocal handler + if handler: + return + handler = _handler + # patch the object + handler.handle_task_exception = errorhandler + # cause an attribute error in the timeout task + handler.connection_init_wait_timeout = None + + with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init): + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await asyncio.sleep(0.01) # wait for the timeout task to start + await ws.send_json(ConnectionInitMessage().as_dict()) + response = await ws.receive_json() + assert response == ConnectionAckMessage().as_dict() + await ws.close() + + # the error hander should have been called + assert handler + errorhandler.assert_called_once() + args = errorhandler.call_args + assert isinstance(args[0][0], AttributeError) + assert "total_seconds" in str(args[0][0])