diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8f6812b821 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Fix error handling for query operations over graphql-transport-ws \ No newline at end of file diff --git a/tests/views/schema.py b/tests/views/schema.py index b0c14bfd76..cf50dfe415 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -20,6 +20,19 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} @@ -80,6 +93,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str def always_fail(self) -> Optional[str]: return "Hey" + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" + @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: yield GraphQLError(message) # type: ignore @@ -262,6 +281,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" # pragma: no cover + class Schema(strawberry.Schema): def process_errors( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 4dbea524f4..fb4541978c 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -13,6 +13,9 @@ from pytest_mock import MockerFixture from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -437,6 +440,28 @@ async def test_subscription_field_errors(ws: WebSocketClient): process_errors.assert_called_once() +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." + ) + + async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( SubscribeMessage( @@ -963,3 +988,137 @@ async def test_no_extensions_results_wont_send_extensions_in_payload( mock.assert_called_once() assert_next(response, "sub1", {"echo": "Hi"}) assert "extensions" not in response["payload"] + + +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_long_validation_concurrent_query(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + single-result-operation + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first query is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) + + +async def test_long_validation_concurrent_subscription(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) + + +async def test_task_error_handler(ws: WebSocketClient): + """ + Test that error handling works + """ + # can't use a simple Event here, because the handler may run + # on a different thread + wakeup = False + + # a replacement method which causes an error in th eTask + async def op(*args: Any, **kwargs: Any): + nonlocal wakeup + wakeup = True + raise ZeroDivisionError("test") + + with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger: + with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op): + # send any old subscription request. It will raise an error + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0) }" + ), + ).as_dict() + ) + + # wait for the error to be logged. Must use timed loop and not event. + while not wakeup: # noqa: ASYNC110 + await asyncio.sleep(0.01) + # and another little bit, for the thread to finish + await asyncio.sleep(0.01) + assert logger.exception.called