Skip to content

Commit

Permalink
Fix problems with timeouts in graphql_transport_ws (#2703)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
kristjanvalur and patrick91 authored May 2, 2023
1 parent e7aac0e commit 6e730d9
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 52 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release improves the `graphql-transport-ws` implementation by starting the sub-protocol timeout only when the connection handshake is completed.
5 changes: 2 additions & 3 deletions strawberry/aiohttp/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> web.StreamResponse:
await self._ws.prepare(self._request)
self.on_request_accepted()

try:
async for ws_message in self._ws: # type: http.WSMessage
Expand All @@ -53,8 +54,6 @@ async def handle_request(self) -> web.StreamResponse:
error_message = "WebSocket message type must be text"
await self.handle_invalid_message(error_message)
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()

return self._ws
5 changes: 2 additions & 3 deletions strawberry/asgi/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> None:
await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

try:
while self._ws.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -59,6 +60,4 @@ async def handle_request(self) -> None:
except WebSocketDisconnect: # pragma: no cover
pass
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()
6 changes: 2 additions & 4 deletions strawberry/channels/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:

async def handle_request(self) -> Any:
await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

async def handle_disconnect(self, code) -> None:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)

await self.reap_completed_tasks()
await self.shutdown()
5 changes: 2 additions & 3 deletions strawberry/starlite/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> None:
await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

try:
while self._ws.connection_state != "disconnect":
Expand All @@ -52,6 +53,4 @@ async def handle_request(self) -> None:
except WebSocketDisconnect: # pragma: no cover
pass
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional
Expand Down Expand Up @@ -35,6 +36,8 @@


class BaseGraphQLTransportWSHandler(ABC):
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")

def __init__(
self,
schema: BaseSchema,
Expand All @@ -47,6 +50,7 @@ def __init__(
self.connection_init_timeout_task: Optional[asyncio.Task] = None
self.connection_init_received = False
self.connection_acknowledged = False
self.connection_timed_out = False
self.subscriptions: Dict[str, AsyncGenerator] = {}
self.tasks: Dict[str, asyncio.Task] = {}
self.completed_tasks: List[asyncio.Task] = []
Expand All @@ -73,19 +77,50 @@ async def handle_request(self) -> Any:
"""Handle the request this instance was created for"""

async def handle(self) -> Any:
timeout_handler = self.handle_connection_init_timeout()
self.connection_init_timeout_task = asyncio.create_task(timeout_handler)
return await self.handle_request()

async def shutdown(self) -> None:
    """Tear down everything this handler started for the connection.

    Cancels the connection-init timeout task (when one was created) and
    waits for it to finish, then cleans up every registered subscription
    and reaps the completed tasks.
    """
    timeout_task = self.connection_init_timeout_task
    if timeout_task:
        timeout_task.cancel()
        # Awaiting the cancelled task re-raises CancelledError here; that
        # is the expected outcome, so it is deliberately ignored.
        try:
            await timeout_task
        except asyncio.CancelledError:
            pass

    # Iterate over a snapshot of the ids rather than the live mapping.
    for operation_id in list(self.subscriptions):
        await self.cleanup_operation(operation_id)
    await self.reap_completed_tasks()

def on_request_accepted(self) -> None:
    """Start the connection-init timeout task.

    handle_request implementations call this once the websocket accept
    response has been sent, so the timeout only starts counting after the
    handshake. It must run at most once per handler: a timeout task that
    already exists trips the assertion.
    """
    assert not self.connection_init_timeout_task
    timeout_handler = self.handle_connection_init_timeout()
    self.connection_init_timeout_task = asyncio.create_task(timeout_handler)

async def handle_connection_init_timeout(self) -> None:
delay = self.connection_init_wait_timeout.total_seconds()
await asyncio.sleep(delay=delay)
task = asyncio.current_task()
assert task
try:
delay = self.connection_init_wait_timeout.total_seconds()
await asyncio.sleep(delay=delay)

if self.connection_init_received:
return
if self.connection_init_received:
return # pragma: no cover

self.connection_timed_out = True
reason = "Connection initialisation timeout"
await self.close(code=4408, reason=reason)
except asyncio.CancelledError:
raise
except Exception as error:
await self.handle_task_exception(error) # pragma: no cover
finally:
# do not clear self.connection_init_timeout_task
# so that unittests can inspect it.
self.completed_tasks.append(task)

reason = "Connection initialisation timeout"
await self.close(code=4408, reason=reason)
async def handle_task_exception(self, error: Exception) -> None:
    """Report an exception raised inside a background task.

    The error is passed as ``exc_info`` to ``self.task_logger.exception``,
    so it is logged together with the exception details.
    """
    log = self.task_logger
    log.exception("Exception in worker task", exc_info=error)

async def handle_message(self, message: dict) -> None:
handler: Callable
Expand Down Expand Up @@ -126,6 +161,12 @@ async def handle_message(self, message: dict) -> None:
await self.reap_completed_tasks()

async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
if self.connection_timed_out:
# No way to reliably exercise this case during testing
return # pragma: no cover
if self.connection_init_timeout_task:
self.connection_init_timeout_task.cancel()

if message.payload is not UNSET and not isinstance(message.payload, dict):
await self.close(code=4400, reason="Invalid connection init payload")
return
Expand Down Expand Up @@ -228,6 +269,7 @@ async def operation_task(
Operation task top level method. Cleans up and de-registers the operation
once it is done.
"""
# TODO: Handle errors in this method using self.handle_task_exception()
try:
await self.handle_async_results(result_source, operation)
except BaseException: # pragma: no cover
Expand Down
11 changes: 11 additions & 0 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,17 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]:


class DebuggableGraphQLTransportWSMixin:
@staticmethod
def on_init(self):
"""
This method can be patched by unittests to get the instance of the
transport handler when it is initialized
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
DebuggableGraphQLTransportWSMixin.on_init(self)

async def get_context(self) -> object:
context = await super().get_context()
context["ws"] = self._ws
Expand Down
103 changes: 72 additions & 31 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import time
from datetime import timedelta
from typing import AsyncGenerator, Type
from unittest.mock import patch

try:
from unittest.mock import AsyncMock
except ImportError:
AsyncMock = None

import pytest
import pytest_asyncio
Expand All @@ -19,7 +25,8 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from tests.http.clients import AioHttpClient
from tests.http.clients import AioHttpClient, ChannelsHttpClient
from tests.http.clients.base import DebuggableGraphQLTransportWSMixin

from ..http.clients import HttpClient, WebSocketClient

Expand Down Expand Up @@ -120,40 +127,33 @@ async def test_connection_init_timeout(request, http_client_class: Type[HttpClie

@pytest.mark.flaky
async def test_connection_init_timeout_cancellation(
http_client_class: Type[HttpClient],
ws_raw: WebSocketClient,
):
test_client = http_client_class()
test_client.create_app(connection_init_wait_timeout=timedelta(milliseconds=100))
async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_json(ConnectionInitMessage().as_dict())

response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()

await asyncio.sleep(0.2)
# Verify that the timeout task is cancelled after the connection Init
# message is received
ws = ws_raw
await ws.send_json(ConnectionInitMessage().as_dict())

await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { debug { isConnectionInitTimeoutTaskDone } }"
),
).as_dict()
)
response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()

response = await ws.receive_json()
assert (
response
== NextMessage(
id="sub1",
payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}},
).as_dict()
)
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { debug { isConnectionInitTimeoutTaskDone } }"
),
).as_dict()
)

await ws.close()
assert ws.closed
response = await ws.receive_json()
assert (
response
== NextMessage(
id="sub1",
payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}},
).as_dict()
)


async def test_too_many_initialisation_requests(ws: WebSocketClient):
Expand Down Expand Up @@ -790,3 +790,44 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient):
end = time.time()
elapsed = end - start
assert elapsed < delay


async def test_error_handler_for_timeout(http_client: HttpClient):
    """
    Test that the error handler is called when the timeout
    task encounters an error.

    The transport handler is captured through the patched ``on_init`` hook.
    Its ``handle_task_exception`` is replaced with a mock, and its
    ``connection_init_wait_timeout`` is set to ``None`` so that the timeout
    task fails with an ``AttributeError`` on ``total_seconds``.
    """
    if isinstance(http_client, ChannelsHttpClient):
        pytest.skip("Can't patch on_init for this client")
    if AsyncMock is None:
        pytest.skip("Don't have AsyncMock")
    handler = None
    errorhandler = AsyncMock()

    def on_init(_handler):
        nonlocal handler
        # Only instrument the first handler that gets created.
        if handler:
            return
        handler = _handler
        # patch the object
        handler.handle_task_exception = errorhandler
        # cause an attribute error in the timeout task
        handler.connection_init_wait_timeout = None

    with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init):
        async with http_client.ws_connect(
            "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
        ) as ws:
            await asyncio.sleep(0.01)  # wait for the timeout task to start
            await ws.send_json(ConnectionInitMessage().as_dict())
            response = await ws.receive_json()
            assert response == ConnectionAckMessage().as_dict()
            await ws.close()

    # the error handler should have been called
    assert handler
    errorhandler.assert_called_once()
    args = errorhandler.call_args
    assert isinstance(args[0][0], AttributeError)
    assert "total_seconds" in str(args[0][0])

0 comments on commit 6e730d9

Please sign in to comment.