Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problems with timeouts in graphql_transport_ws #2703

Merged
5 changes: 2 additions & 3 deletions strawberry/aiohttp/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> web.StreamResponse:
await self._ws.prepare(self._request)
self.on_request_accepted()

try:
async for ws_message in self._ws: # type: http.WSMessage
Expand All @@ -53,8 +54,6 @@ async def handle_request(self) -> web.StreamResponse:
error_message = "WebSocket message type must be text"
await self.handle_invalid_message(error_message)
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()

return self._ws
5 changes: 2 additions & 3 deletions strawberry/asgi/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> None:
await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

try:
while self._ws.application_state != WebSocketState.DISCONNECTED:
Expand All @@ -59,6 +60,4 @@ async def handle_request(self) -> None:
except WebSocketDisconnect: # pragma: no cover
pass
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()
6 changes: 2 additions & 4 deletions strawberry/channels/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:

async def handle_request(self) -> Any:
await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

async def handle_disconnect(self, code) -> None:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)

await self.reap_completed_tasks()
await self.shutdown()
5 changes: 2 additions & 3 deletions strawberry/starlite/handlers/graphql_transport_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def close(self, code: int, reason: str) -> None:

async def handle_request(self) -> None:
await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL)
self.on_request_accepted()

try:
while self._ws.connection_state != "disconnect":
Expand All @@ -52,6 +53,4 @@ async def handle_request(self) -> None:
except WebSocketDisconnect: # pragma: no cover
pass
finally:
for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()
await self.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.connection_init_timeout_task: Optional[asyncio.Task] = None
self.connection_init_received = False
self.connection_acknowledged = False
self.connection_timed_out = False
self.subscriptions: Dict[str, AsyncGenerator] = {}
self.tasks: Dict[str, asyncio.Task] = {}
self.completed_tasks: List[asyncio.Task] = []
Expand All @@ -73,19 +74,51 @@ async def handle_request(self) -> Any:
"""Handle the request this instance was created for"""

async def handle(self) -> Any:
timeout_handler = self.handle_connection_init_timeout()
self.connection_init_timeout_task = asyncio.create_task(timeout_handler)
return await self.handle_request()

async def shutdown(self) -> None:
if self.connection_init_timeout_task:
self.connection_init_timeout_task.cancel()
with suppress(asyncio.CancelledError):
await self.connection_init_timeout_task

for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.reap_completed_tasks()

def on_request_accepted(self) -> None:
# handle_request should call this once it has sent the
kristjanvalur marked this conversation as resolved.
Show resolved Hide resolved
# websocket.accept() response to start the timeout.
assert not self.connection_init_timeout_task
self.connection_init_timeout_task = asyncio.create_task(
self.handle_connection_init_timeout()
)

async def handle_connection_init_timeout(self) -> None:
delay = self.connection_init_wait_timeout.total_seconds()
await asyncio.sleep(delay=delay)
task = asyncio.current_task()
assert task
try:
delay = self.connection_init_wait_timeout.total_seconds()
await asyncio.sleep(delay=delay)

if self.connection_init_received:
return
if self.connection_init_received:
return # pragma: no cover

reason = "Connection initialisation timeout"
await self.close(code=4408, reason=reason)
self.connection_timed_out = True
reason = "Connection initialisation timeout"
await self.close(code=4408, reason=reason)
except asyncio.CancelledError:
raise
except Exception as error:
await self.handle_task_exception(error) # pragma: no cover
finally:
# do not clear self.connection_init_timeout_task
# so that unittests can inspect it.
self.completed_tasks.append(task)

async def handle_task_exception(self, _: Exception) -> None:
# TODO: Log the error
pass # pragma: no cover

async def handle_message(self, message: dict) -> None:
handler: Callable
Expand Down Expand Up @@ -126,6 +159,12 @@ async def handle_message(self, message: dict) -> None:
await self.reap_completed_tasks()

async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
if self.connection_timed_out:
# No way to reliably excercise this case during testing
return # pragma: no cover
if self.connection_init_timeout_task:
self.connection_init_timeout_task.cancel()

if message.payload is not UNSET and not isinstance(message.payload, dict):
await self.close(code=4400, reason="Invalid connection init payload")
return
Expand Down Expand Up @@ -228,6 +267,7 @@ async def operation_task(
Operation task top level method. Cleans up and de-registers the operation
once it is done.
"""
# TODO: Handle errors in this method using self.handle_task_exception()
try:
await self.handle_async_results(result_source, operation)
except BaseException: # pragma: no cover
Expand Down
11 changes: 11 additions & 0 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,17 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]:


class DebuggableGraphQLTransportWSMixin:
@staticmethod
def on_init(self):
"""
This method can be patched by unittests to get the instance of the
transport handler when it is initialized
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
DebuggableGraphQLTransportWSMixin.on_init(self)

async def get_context(self) -> object:
context = await super().get_context()
context["ws"] = self._ws
Expand Down
103 changes: 72 additions & 31 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import time
from datetime import timedelta
from typing import AsyncGenerator, Type
from unittest.mock import patch

try:
from unittest.mock import AsyncMock
except ImportError:
AsyncMock = None

import pytest
import pytest_asyncio
Expand All @@ -19,7 +25,8 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from tests.http.clients import AioHttpClient
from tests.http.clients import AioHttpClient, ChannelsHttpClient
from tests.http.clients.base import DebuggableGraphQLTransportWSMixin

from ..http.clients import HttpClient, WebSocketClient

Expand Down Expand Up @@ -119,40 +126,33 @@ async def test_connection_init_timeout(request, http_client_class: Type[HttpClie


async def test_connection_init_timeout_cancellation(
http_client_class: Type[HttpClient],
ws_raw: WebSocketClient,
):
test_client = http_client_class()
test_client.create_app(connection_init_wait_timeout=timedelta(milliseconds=100))
async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_json(ConnectionInitMessage().as_dict())

response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()

await asyncio.sleep(0.2)
# Verify that the timeout task is cancelled after the connection Init
# message is received
ws = ws_raw
await ws.send_json(ConnectionInitMessage().as_dict())

await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { debug { isConnectionInitTimeoutTaskDone } }"
),
).as_dict()
)
response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()

response = await ws.receive_json()
assert (
response
== NextMessage(
id="sub1",
payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}},
).as_dict()
)
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { debug { isConnectionInitTimeoutTaskDone } }"
),
).as_dict()
)

await ws.close()
assert ws.closed
response = await ws.receive_json()
assert (
response
== NextMessage(
id="sub1",
payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}},
).as_dict()
)


async def test_too_many_initialisation_requests(ws: WebSocketClient):
Expand Down Expand Up @@ -789,3 +789,44 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient):
end = time.time()
elapsed = end - start
assert elapsed < delay


async def test_error_handler_for_timeout(http_client: HttpClient):
"""
Test that the error handler is called when the timeout
task encounters an error
"""
if isinstance(http_client, ChannelsHttpClient):
pytest.skip("Can't patch on_init for this client")
if not AsyncMock:
pytest.skip("Don't have AsyncMock")
ws = ws_raw
handler = None
errorhandler = AsyncMock()

def on_init(_handler):
nonlocal handler
if handler:
return
handler = _handler
# patch the object
handler.handle_task_exception = errorhandler
# cause an attribute error in the timeout task
handler.connection_init_wait_timeout = None

with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init):
async with http_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await asyncio.sleep(0.01) # wait for the timeout task to start
await ws.send_json(ConnectionInitMessage().as_dict())
response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()
await ws.close()

# the error hander should have been called
assert handler
errorhandler.assert_called_once()
args = errorhandler.call_args
assert isinstance(args[0][0], AttributeError)
assert "total_seconds" in str(args[0][0])