From 3445d24c66f388604b7572d0e3b9bc084796b37d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 8 Mar 2023 22:39:22 +0000 Subject: [PATCH 01/24] Add tests for blocking operation validation --- tests/views/schema.py | 25 +++++++ tests/websockets/test_graphql_transport_ws.py | 67 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/tests/views/schema.py b/tests/views/schema.py index b0c14bfd76..2adb1e23a0 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -20,6 +20,19 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} @@ -80,6 +93,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str def always_fail(self) -> Optional[str]: return "Hey" + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" + @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: yield GraphQLError(message) # type: ignore @@ -262,6 +281,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" + class Schema(strawberry.Schema): def process_errors( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index f3fd4b74b8..d46f3c7ddc 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -962,3 +962,70 @@ async def test_subscription_errors_continue(ws: WebSocketClient): response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +async def test_long_validation_concurrent_query(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + single-result-operation + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first query is stuck in validation + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"conditionalFail": "Hey"}} + ).as_dict() + ) + + +@pytest.mark.xfail +async def test_long_validation_concurrent_subscription(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck in validation + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"conditionalFail": "Hey"}} + ).as_dict() + ) From 2be1bf4c51de9ddd21205802160cab201cdaca27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 8 Apr 2023 15:11:16 +0000 Subject: [PATCH 02/24] Move validation into the task --- .../graphql_transport_ws/handlers.py | 137 ++++++++++-------- tests/websockets/test_graphql_transport_ws.py | 1 - 2 files changed, 80 insertions(+), 58 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 1275ecf304..70f94ef7f8 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -10,6 +10,7 @@ AsyncGenerator, AsyncIterator, Callable, + Coroutine, Dict, List, Optional, @@ -245,57 +246,49 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: root_value = await self.get_root_value() # Get an AsyncGenerator yielding the results - if operation_type == OperationType.SUBSCRIPTION: - result_source = await self.schema.subscribe( - query=message.payload.query, - variable_values=message.payload.variables, - operation_name=message.payload.operationName, - context_value=context, - root_value=root_value, - ) - else: - # create AsyncGenerator returning a single result - async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( # type: ignore + async def start_operation() -> AsyncGenerator[ExecutionResult, None]: + if operation_type == OperationType.SUBSCRIPTION: + return await self.schema.subscribe( + query=message.payload.query, + variable_values=message.payload.variables, + operation_name=message.payload.operationName, + context_value=context, + root_value=root_value, + ) + else: + # single results behave similarly to subscriptions, + # return either a GraphQLExecutionResult or an AsyncGenerator + result = await self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, context_value=context, root_value=root_value, operation_name=message.payload.operationName, ) + if isinstance(result, GraphQLExecutionResult): + return result - result_source = get_result_source() + # create AsyncGenerator returning a single result + async def single_result() -> AsyncIterator[ExecutionResult]: + yield result - operation = Operation(self, message.id, operation_type) - - # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): - assert operation_type == OperationType.SUBSCRIPTION - assert result_source.errors - payload = [err.formatted for err in result_source.errors] - await self.send_message(ErrorMessage(id=message.id, payload=payload)) - self.schema.process_errors(result_source.errors) - return + return single_result() # Create task to handle this subscription, reserve the operation ID - operation.task = asyncio.create_task( - self.operation_task(result_source, operation) - ) + operation = Operation(self, message.id, operation_type, start_operation) + operation.task = asyncio.create_task(self.operation_task(operation)) self.operations[message.id] = operation - async def operation_task( - self, result_source: AsyncGenerator, operation: Operation - ) -> None: - """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" + async def operation_task(self, operation: Operation) -> None: + """ + The operation task's top level method. Cleans-up and de-registers the operation + once it is done. + """ # TODO: Handle errors in this method using self.handle_task_exception() try: - await self.handle_async_results(result_source, operation) + await self.handle_async_results(operation) except BaseException: # pragma: no cover # cleanup in case of something really unexpected - # wait for generator to be closed to ensure that any existing - # 'finally' statement is called - with suppress(RuntimeError): - await result_source.aclose() if operation.id in self.operations: del self.operations[operation.id] raise @@ -309,30 +302,49 @@ async def operation_task( async def handle_async_results( self, - result_source: AsyncGenerator, operation: Operation, ) -> None: try: - async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() - return - else: - next_payload = {"data": result.data} - if result.errors: - self.schema.process_errors(result.errors) - next_payload["errors"] = [ - err.formatted for err in result.errors - ] - next_message = NextMessage(id=operation.id, payload=next_payload) - await operation.send_message(next_message) + result_source = await operation.start_operation() + # Handle validation errors + if isinstance(result_source, GraphQLExecutionResult): + assert result_source.errors + payload = [err.formatted for err in result_source.errors] + await operation.send_message( + ErrorMessage(id=operation.id, payload=payload) + ) + self.schema.process_errors(result_source.errors) + return + + try: + async for result in result_source: + if ( + result.errors + and operation.operation_type != OperationType.SUBSCRIPTION + ): + error_payload = [err.formatted for err in result.errors] + error_message = ErrorMessage( + id=operation.id, payload=error_payload + ) + await operation.send_message(error_message) + # don't need to call schema.process_errors() here because + # it was already done by schema.execute() + return + else: + next_payload = {"data": result.data} + if result.errors: + self.schema.process_errors(result.errors) + next_payload["errors"] = [ + err.formatted for err in result.errors + ] + next_message = NextMessage( + id=operation.id, payload=next_payload + ) + await operation.send_message(next_message) + finally: + # Close the AsyncGenerator in case of errors or cancellation + await result_source.aclose() + except Exception as error: # GraphQLErrors are handled by graphql-core and included in the # ExecutionResult @@ -378,17 +390,28 @@ async def reap_completed_tasks(self) -> None: class Operation: """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" - __slots__ = ["handler", "id", "operation_type", "completed", "task"] + __slots__ = [ + "handler", + "id", + "operation_type", + "start_operation", + "completed", + "task", + ] def __init__( self, handler: BaseGraphQLTransportWSHandler, id: str, operation_type: OperationType, + start_operation: Callable[ + [], Coroutine[Any, Any, AsyncGenerator[ExecutionResult, None]] + ], ) -> None: self.handler = handler self.id = id self.operation_type = operation_type + self.start_operation = start_operation self.completed = False self.task: Optional[asyncio.Task] = None diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index d46f3c7ddc..ee842a0102 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -997,7 +997,6 @@ async def test_long_validation_concurrent_query(ws: WebSocketClient): ) -@pytest.mark.xfail async def test_long_validation_concurrent_subscription(ws: WebSocketClient): """ Test that the websocket is not blocked while validating a From 02772eee50a7446f9a9c010f91b258b7cc79da85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 11:22:47 +0000 Subject: [PATCH 03/24] Add some error/validation test cases --- tests/websockets/test_graphql_transport_ws.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index ee842a0102..60016e06cb 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -407,6 +407,28 @@ async def test_subscription_field_errors(ws: WebSocketClient): process_errors.assert_called_once() +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." + ) + + async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( SubscribeMessage( @@ -964,6 +986,50 @@ async def test_subscription_errors_continue(ws: WebSocketClient): assert response["id"] == "sub1" +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + async def test_long_validation_concurrent_query(ws: WebSocketClient): """ Test that the websocket is not blocked while validating a From 7687c4e29c39234186d949e208dbc84b36921644 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 12:43:46 +0000 Subject: [PATCH 04/24] Use duck typing to detect an ExecutionResult/GraphQLExeucitonResult --- .../graphql_transport_ws/handlers.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 70f94ef7f8..48b82d0b42 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -14,9 +14,9 @@ Dict, List, Optional, + Union, ) -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -246,7 +246,13 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: root_value = await self.get_root_value() # Get an AsyncGenerator yielding the results - async def start_operation() -> AsyncGenerator[ExecutionResult, None]: + async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: + # there is some type mismatch here which we need to gloss over with + # the use of Any. + # subscribe() returns + # Union[AsyncIterator[graphql.ExecutionResult], graphql.ExecutionResult]: + # whereas execute() returns strawberry.types.ExecutionResult. + # These execution result types are similar, but not the same. if operation_type == OperationType.SUBSCRIPTION: return await self.schema.subscribe( query=message.payload.query, @@ -257,7 +263,7 @@ async def start_operation() -> AsyncGenerator[ExecutionResult, None]: ) else: # single results behave similarly to subscriptions, - # return either a GraphQLExecutionResult or an AsyncGenerator + # return either a ExecutionResult or an AsyncGenerator result = await self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, @@ -265,7 +271,8 @@ async def start_operation() -> AsyncGenerator[ExecutionResult, None]: root_value=root_value, operation_name=message.payload.operationName, ) - if isinstance(result, GraphQLExecutionResult): + # Both validation and execution errors are handled the same way. + if result.errors: return result # create AsyncGenerator returning a single result @@ -286,28 +293,27 @@ async def operation_task(self, operation: Operation) -> None: """ # TODO: Handle errors in this method using self.handle_task_exception() try: - await self.handle_async_results(operation) + await self.handle_operation(operation) except BaseException: # pragma: no cover # cleanup in case of something really unexpected if operation.id in self.operations: del self.operations[operation.id] raise - else: - await operation.send_message(CompleteMessage(id=operation.id)) finally: # add this task to a list to be reaped later task = asyncio.current_task() assert task is not None self.completed_tasks.append(task) - async def handle_async_results( + async def handle_operation( self, operation: Operation, ) -> None: try: result_source = await operation.start_operation() - # Handle validation errors - if isinstance(result_source, GraphQLExecutionResult): + # result_source is an ExcutionResult-like object or an AsyncGenerator + # Handle validation errors. Cannot check type directly. + if hasattr(result_source, "errors"): assert result_source.errors payload = [err.formatted for err in result_source.errors] await operation.send_message( @@ -341,6 +347,7 @@ async def handle_async_results( id=operation.id, payload=next_payload ) await operation.send_message(next_message) + await operation.send_message(CompleteMessage(id=operation.id)) finally: # Close the AsyncGenerator in case of errors or cancellation await result_source.aclose() @@ -405,7 +412,7 @@ def __init__( id: str, operation_type: OperationType, start_operation: Callable[ - [], Coroutine[Any, Any, AsyncGenerator[ExecutionResult, None]] + [], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]] ], ) -> None: self.handler = handler From 619800790e7d1003985d7f2e6641f06b22d4581f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 14:50:39 +0000 Subject: [PATCH 05/24] add an async context getter for tests which is easily patchable. --- tests/http/clients/aiohttp.py | 4 ++-- tests/http/clients/asgi.py | 4 ++-- tests/http/clients/channels.py | 6 +++--- tests/http/clients/fastapi.py | 4 ++-- tests/http/context.py | 15 +++++++++++++++ 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..92479f8469 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -50,7 +50,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return get_context(context) + return await get_context(context) async def get_root_value(self, request: web.Request) -> Query: await super().get_root_value(request) # for coverage diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..967f6adcf8 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -18,7 +18,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -56,7 +56,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return get_context(context) + return await get_context(context) async def process_result( self, request: Request, result: ExecutionResult diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index da981403bb..34d143b01e 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -20,7 +20,7 @@ from strawberry.http.typevars import Context, RootValue from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context, get_context_async from .base import ( JSON, HttpClient, @@ -77,7 +77,7 @@ async def get_context(self, *args: str, **kwargs: Any) -> object: context["connectionInitTimeoutTask"] = getattr( self._handler, "connection_init_timeout_task", None ) - for key, val in get_context({}).items(): + for key, val in (await get_context_async({})).items(): context[key] = val return context @@ -95,7 +95,7 @@ async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue] async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: context = await super().get_context(request, response) - return get_context(context) + return await get_context_async(context) async def process_result( self, request: ChannelsConsumer, result: Any diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 1a8148c136..f634bf88e1 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .asgi import AsgiWebSocketClient from .base import ( JSON, @@ -50,7 +50,7 @@ async def fastapi_get_context( ws: WebSocket = None, # type: ignore custom_value: str = Depends(custom_context_dependency), ) -> Dict[str, object]: - return get_context( + return await get_context( { "request": request or ws, "background_tasks": background_tasks, diff --git a/tests/http/context.py b/tests/http/context.py index 99985b2434..c1ce5dbecf 100644 --- a/tests/http/context.py +++ b/tests/http/context.py @@ -2,6 +2,21 @@ def get_context(context: object) -> Dict[str, object]: + return get_context_inner(context) + + +# a patchable method for unittests +def get_context_inner(context: object) -> Dict[str, object]: assert isinstance(context, dict) + return {**context, "custom_value": "a value from context"} + +# async version for async frameworks +async def get_context_async(context: object) -> Dict[str, object]: + return await get_context_async_inner(context) + + +# a patchable method for unittests +async def get_context_async_inner(context: object) -> Dict[str, object]: + assert isinstance(context, dict) return {**context, "custom_value": "a value from context"} From 9ef3bf0f00469c963d00485516c76b4a9a7d995e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 13:00:40 +0000 Subject: [PATCH 06/24] Add tests to ensure context_getter does not block connection --- tests/websockets/test_graphql_transport_ws.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 60016e06cb..b6d859121a 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1094,3 +1094,60 @@ async def test_long_validation_concurrent_subscription(ws: WebSocketClient): id="sub2", payload={"data": {"conditionalFail": "Hey"}} ).as_dict() ) + + +@pytest.mark.xfail +async def test_long_custom_context(ws: WebSocketClient): + """ + Test that the websocket is not blocked evaluating the context + """ + + counter = 0 + + async def slow_get_context(ctxt): + nonlocal counter + old = counter + counter += 1 + if old == 0: + await asyncio.sleep(0.1) + ctxt["custom_value"] = "slow" + else: + ctxt["custom_value"] = "fast" + return ctxt + + with patch("tests.http.context.get_context_async_inner", slow_get_context): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload(query="query { valueFromContext }"), + ).as_dict() + ) + + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { valueFromContext }", + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck getting context + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"valueFromContext": "fast"}} + ).as_dict() + ) + + response = await ws.receive_json() + if response == CompleteMessage(id="sub2").as_dict(): + response = await ws.receive_json() # ignore the complete message + assert ( + response + == NextMessage( + id="sub1", payload={"data": {"valueFromContext": "slow"}} + ).as_dict() + ) From 7cf537efa690fba106c7dcfb64b85a4cb4948eae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 16:31:33 +0000 Subject: [PATCH 07/24] Move context getter and root getter into worker task --- .../graphql_transport_ws/handlers.py | 15 ++++++++------ tests/websockets/test_graphql_transport_ws.py | 20 ++++++++++++++++--- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 48b82d0b42..d706a6025d 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -240,12 +240,9 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: message.payload.variables, ) - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() - - # Get an AsyncGenerator yielding the results + # The method to start the operation. Will be called on worker + # thread and so may contain long running async calls without + # blocking the main websocket handler. async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: # there is some type mismatch here which we need to gloss over with # the use of Any. @@ -253,6 +250,12 @@ async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: # Union[AsyncIterator[graphql.ExecutionResult], graphql.ExecutionResult]: # whereas execute() returns strawberry.types.ExecutionResult. # These execution result types are similar, but not the same. + + context = await self.get_context() + if isinstance(context, dict): + context["connection_params"] = self.connection_params + root_value = await self.get_root_value() + if operation_type == OperationType.SUBSCRIPTION: return await self.schema.subscribe( query=message.payload.query, diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index b6d859121a..83867ecbe4 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -32,8 +32,19 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from tests.views.schema import Schema +from ..http.clients.base import WebSocketClient + +try: + from ..http.clients.fastapi import FastAPIHttpClient +except ImportError: + FastAPIHttpClient = None +try: + from ..http.clients.starlite import StarliteHttpClient +except ImportError: + StarliteHttpClient = None + if TYPE_CHECKING: - from ..http.clients.base import HttpClient, WebSocketClient + from ..http.clients.base import HttpClient @pytest_asyncio.fixture @@ -1096,11 +1107,14 @@ async def test_long_validation_concurrent_subscription(ws: WebSocketClient): ) -@pytest.mark.xfail -async def test_long_custom_context(ws: WebSocketClient): +async def test_long_custom_context( + ws: WebSocketClient, http_client_class: Type[HttpClient] +): """ Test that the websocket is not blocked evaluating the context """ + if http_client_class in (FastAPIHttpClient, StarliteHttpClient): + pytest.skip("Client evaluates the context only once per connection") counter = 0 From e41ffe9e30b7130f950aeab5bccb661a838b55c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 8 Jun 2023 14:39:39 +0000 Subject: [PATCH 08/24] Catch top level errors --- .../protocols/graphql_transport_ws/handlers.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index d706a6025d..4f4a8d2497 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -108,7 +108,7 @@ def on_request_accepted(self) -> None: async def handle_connection_init_timeout(self) -> None: task = asyncio.current_task() - assert task + assert task is not None # for typecheckers try: delay = self.connection_init_wait_timeout.total_seconds() await asyncio.sleep(delay=delay) @@ -294,18 +294,20 @@ async def operation_task(self, operation: Operation) -> None: The operation task's top level method. Cleans-up and de-registers the operation once it is done. """ - # TODO: Handle errors in this method using self.handle_task_exception() + task = asyncio.current_task() + assert task is not None # for type checkers try: await self.handle_operation(operation) - except BaseException: # pragma: no cover - # cleanup in case of something really unexpected - if operation.id in self.operations: - del self.operations[operation.id] + except asyncio.CancelledError: raise + except Exception as error: + await self.handle_task_exception(error) + # cleanup in case of something really unexpected finally: # add this task to a list to be reaped later - task = asyncio.current_task() - assert task is not None + if operation.id in self.operations: + del self.operations[operation.id] + # TODO: Stop collecting background tasks, not necessary. self.completed_tasks.append(task) async def handle_operation( From c3d6447c93a8fc725d1c81eef7b801742e9db0b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 9 Jun 2023 10:28:27 +0000 Subject: [PATCH 09/24] Add a test for the task error handler --- .../graphql_transport_ws/handlers.py | 5 ++- tests/websockets/test_graphql_transport_ws.py | 37 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 4f4a8d2497..273cf97f61 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -301,13 +301,14 @@ async def operation_task(self, operation: Operation) -> None: except asyncio.CancelledError: raise except Exception as error: + # Log any unhandled exceptions in the operation task await self.handle_task_exception(error) - # cleanup in case of something really unexpected finally: - # add this task to a list to be reaped later + # Clenaup. Remove the operation from the list of active operations if operation.id in self.operations: del self.operations[operation.id] # TODO: Stop collecting background tasks, not necessary. + # Add this task to a list to be reaped later self.completed_tasks.append(task) async def handle_operation( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 83867ecbe4..11dffa15cd 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -18,6 +18,9 @@ from pytest_mock import MockerFixture from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -1165,3 +1168,37 @@ async def slow_get_context(ctxt): id="sub1", payload={"data": {"valueFromContext": "slow"}} ).as_dict() ) + + +async def test_task_error_handler(ws: WebSocketClient): + """ + Test that error handling works + """ + # can't use a simple Event here, because the handler may run + # on a different thread + wakeup = False + + # a replacement method which causes an error in th eTask + async def op(*args: Any, **kwargs: Any): + nonlocal wakeup + wakeup = True + raise ZeroDivisionError("test") + + with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger: + with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op): + # send any old subscription request. It will raise an error + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0) }" + ), + ).as_dict() + ) + + # wait for the error to be logged + while not wakeup: + await asyncio.sleep(0.01) + # and another little bit, for the thread to finish + await asyncio.sleep(0.01) + assert logger.exception.called From a9589e467e404857c0a2b6ba2e92480b5b58f069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 9 May 2023 20:10:33 +0000 Subject: [PATCH 10/24] add release.md --- RELEASE.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..2820214c74 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: patch + +Operations over `graphql-transport-ws` now create the Context and perform validation on +the worker `Task`, thus not blocking the websocket from accepting messages. From ca229a79e9994317a8c8dd86b049f433d2069400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 30 Jun 2023 14:44:03 +0000 Subject: [PATCH 11/24] Remove dead code, fix coverage --- .../graphql_transport_ws/handlers.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 273cf97f61..5055dc7a1c 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -325,34 +325,20 @@ async def handle_operation( await operation.send_message( ErrorMessage(id=operation.id, payload=payload) ) - self.schema.process_errors(result_source.errors) + if operation.operation_type == OperationType.SUBSCRIPTION: + self.schema.process_errors(result_source.errors) return try: async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage( - id=operation.id, payload=error_payload - ) - await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() - return - else: - next_payload = {"data": result.data} - if result.errors: - self.schema.process_errors(result.errors) - next_payload["errors"] = [ - err.formatted for err in result.errors - ] - next_message = NextMessage( - id=operation.id, payload=next_payload - ) - await operation.send_message(next_message) + next_payload = {"data": result.data} + if result.errors: + self.schema.process_errors(result.errors) + next_payload["errors"] = [ + err.formatted for err in result.errors + ] + next_message = NextMessage(id=operation.id, payload=next_payload) + await operation.send_message(next_message) await operation.send_message(CompleteMessage(id=operation.id)) finally: # Close the AsyncGenerator in case of errors or cancellation @@ -429,8 +415,9 @@ def __init__( self.task: Optional[asyncio.Task] = None async def send_message(self, message: GraphQLTransportMessage) -> None: + # defensive check, should never happen if self.completed: - return + return # pragma: no cover if isinstance(message, (CompleteMessage, ErrorMessage)): self.completed = True # de-register the operation _before_ sending the final message From 1a62b1a61f33a6911b5c48213b0867ef1dd2eb60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 3 Aug 2023 12:20:18 +0000 Subject: [PATCH 12/24] remove special case for AsyncMock --- tests/websockets/test_graphql_transport_ws.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 11dffa15cd..accccb0aef 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -6,12 +6,7 @@ import time from datetime import timedelta from typing import TYPE_CHECKING, Any, AsyncGenerator, Type -from unittest.mock import Mock, patch - -try: - from unittest.mock import AsyncMock -except ImportError: - AsyncMock = None +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio @@ -927,9 +922,6 @@ async def test_error_handler_for_timeout(http_client: HttpClient): if isinstance(http_client, ChannelsHttpClient): pytest.skip("Can't patch on_init for this client") - if not AsyncMock: - pytest.skip("Don't have AsyncMock") - ws = ws_raw handler = None errorhandler = AsyncMock() From 76716d04c1a5d933c52bc8a9e65c8e32532915d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 8 Nov 2023 11:25:59 +0000 Subject: [PATCH 13/24] Add "no cover" to schema code which is designed to not be hit. --- tests/views/schema.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/views/schema.py b/tests/views/schema.py index 2adb1e23a0..572ae8f86f 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -91,13 +91,13 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str @strawberry.field(permission_classes=[AlwaysFailPermission]) def always_fail(self) -> Optional[str]: - return "Hey" + return "Hey" # pragma: no cover @strawberry.field(permission_classes=[ConditionalFailPermission]) def conditional_fail( self, sleep: Optional[float] = None, fail: bool = False ) -> str: - return "Hey" + return "Hey" # pragma: no cover @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: @@ -285,7 +285,7 @@ async def long_finalizer( async def conditional_fail( self, sleep: Optional[float] = None, fail: bool = False ) -> AsyncGenerator[str, None]: - yield "Hey" + yield "Hey" # pragma: no cover class Schema(strawberry.Schema): From 1ad929b2fcadbe4a05b446c3f38d923d2942fbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 31 Mar 2024 11:37:47 +0000 Subject: [PATCH 14/24] Update tests for litestar --- tests/http/clients/litestar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index ccf9999f7f..efd034b62c 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -35,7 +35,7 @@ def custom_context_dependency() -> str: async def litestar_get_context(request: Request = None): - return get_context({"request": request}) + return await get_context({"request": request}) async def get_root_value(request: Request = None): From 8ced4e4c42b3b7a976f797e633e87369a3181068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 31 Mar 2024 11:55:57 +0000 Subject: [PATCH 15/24] Litestar integration must be excluded from long test, like Starlite. --- tests/websockets/test_graphql_transport_ws.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index accccb0aef..e31177a2f8 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -40,6 +40,10 @@ from ..http.clients.starlite import StarliteHttpClient except ImportError: StarliteHttpClient = None +try: + from ..http.clients.litestar import LitestarHttpClient +except ImportError: + LitestarHttpClient = None if TYPE_CHECKING: from ..http.clients.base import HttpClient @@ -1108,7 +1112,7 @@ async def test_long_custom_context( """ Test that the websocket is not blocked evaluating the context """ - if http_client_class in (FastAPIHttpClient, StarliteHttpClient): + if http_client_class in (FastAPIHttpClient, StarliteHttpClient, LitestarHttpClient): pytest.skip("Client evaluates the context only once per connection") counter = 0 From 35a8e68d10e765517d6806b5674132a3de3f8dd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 31 Mar 2024 13:10:14 +0000 Subject: [PATCH 16/24] coverage --- tests/websockets/test_graphql_transport_ws.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index e31177a2f8..7de3c1664e 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -34,15 +34,15 @@ try: from ..http.clients.fastapi import FastAPIHttpClient -except ImportError: +except ImportError: # pragma: no cover FastAPIHttpClient = None try: from ..http.clients.starlite import StarliteHttpClient -except ImportError: +except ImportError: # pragma: no cover StarliteHttpClient = None try: from ..http.clients.litestar import LitestarHttpClient -except ImportError: +except ImportError: # pragma: no cover LitestarHttpClient = None if TYPE_CHECKING: From 4084639be7e1b711fdd450c6a6ac2e237df0f41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 2 Apr 2024 08:35:26 +0000 Subject: [PATCH 17/24] Mark some test schema methods as no cover since they are not always used --- tests/views/schema.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/views/schema.py b/tests/views/schema.py index 572ae8f86f..d8a4e89107 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -77,7 +77,7 @@ class DebugInfo: @strawberry.type class Query: @strawberry.field - def greetings(self) -> str: + def greetings(self) -> str: # pragma: no cover return "hello" @strawberry.field @@ -108,7 +108,7 @@ async def exception(self, message: str) -> str: raise ValueError(message) @strawberry.field - def teapot(self, info: strawberry.Info[Any, None]) -> str: + def teapot(self, info: strawberry.Info[Any, None]) -> str: # pragma: no cover info.context["response"].status_code = 418 return "🫖" @@ -142,7 +142,7 @@ def set_header(self, info: strawberry.Info, name: str) -> str: @strawberry.type class Mutation: @strawberry.mutation - def echo(self, string_to_echo: str) -> str: + def echo(self, string_to_echo: str) -> str: # pragma: no cover return string_to_echo @strawberry.mutation @@ -162,7 +162,7 @@ def read_folder(self, folder: FolderInput) -> List[str]: return list(map(_read_file, folder.files)) @strawberry.mutation - def match_text(self, text_file: Upload, pattern: str) -> str: + def match_text(self, text_file: Upload, pattern: str) -> str: # pragma: no cover text = text_file.read().decode() return pattern if pattern in text else "" @@ -199,7 +199,7 @@ async def exception(self, message: str) -> AsyncGenerator[str, None]: raise ValueError(message) # Without this yield, the method is not recognised as an async generator - yield "Hi" + yield "Hi" # pragma: no cover @strawberry.subscription async def flavors(self) -> AsyncGenerator[Flavor, None]: From d17a4d41306d184f484aaacfc3a92c97d25c1e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 7 Sep 2024 12:19:46 +0000 Subject: [PATCH 18/24] Mypy support for SubscriptionExecutionResult --- .../protocols/graphql_transport_ws/handlers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 5055dc7a1c..d22838d506 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -30,6 +30,7 @@ SubscribeMessage, SubscribeMessagePayload, ) +from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType from strawberry.types.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation @@ -42,7 +43,7 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) - from strawberry.types import ExecutionResult + class BaseGraphQLTransportWSHandler(ABC): @@ -274,13 +275,16 @@ async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: root_value=root_value, operation_name=message.payload.operationName, ) + # Note: result may be SubscriptionExecutionResult or ExecutionResult + # now, but we don't support the former properly yet, hence the "ignore" below. + # Both validation and execution errors are handled the same way. - if result.errors: + if isinstance(result, ExecutionResult) and result.errors: return result # create AsyncGenerator returning a single result async def single_result() -> AsyncIterator[ExecutionResult]: - yield result + yield result # type: ignore return single_result() From a1d0695c8ee62c6fae5496be57016f7ea951e5e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 7 Sep 2024 13:18:32 +0000 Subject: [PATCH 19/24] ruff --- .../protocols/graphql_transport_ws/handlers.py | 7 +++---- tests/websockets/test_graphql_transport_ws.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index d22838d506..94a8ca0903 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -43,7 +43,6 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) - class BaseGraphQLTransportWSHandler(ABC): @@ -294,9 +293,9 @@ async def single_result() -> AsyncIterator[ExecutionResult]: self.operations[message.id] = operation async def operation_task(self, operation: Operation) -> None: - """ - The operation task's top level method. Cleans-up and de-registers the operation - once it is done. + """The operation task's top level method. + + Cleans-up and de-registers the operation once it is done. """ task = asyncio.current_task() assert task is not None # for type checkers diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 7de3c1664e..2b583110e1 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1192,8 +1192,8 @@ async def op(*args: Any, **kwargs: Any): ).as_dict() ) - # wait for the error to be logged - while not wakeup: + # wait for the error to be logged. Must use timed loop and not event. + while not wakeup: # noqa: ASYNC110 await asyncio.sleep(0.01) # and another little bit, for the thread to finish await asyncio.sleep(0.01) From e43aca81dd3eec8aecf52c8906546ede03f550c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 8 Sep 2024 11:49:29 +0000 Subject: [PATCH 20/24] Remove unused method for coverage --- tests/http/clients/litestar.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index efd034b62c..645eb425c3 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -30,10 +30,6 @@ ) -def custom_context_dependency() -> str: - return "Hi!" - - async def litestar_get_context(request: Request = None): return await get_context({"request": request}) From 1be5a06616862cd4411a2d3a990fa6f5e76292c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 13 Oct 2024 19:24:32 +0000 Subject: [PATCH 21/24] revert the handler to original state --- .../graphql_transport_ws/handlers.py | 37 +++++-------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index e8797be929..f74e9def3d 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -8,11 +8,9 @@ Any, Awaitable, Callable, - Coroutine, Dict, List, Optional, - Union, ) from graphql import GraphQLError, GraphQLSyntaxError, parse @@ -105,7 +103,7 @@ def on_request_accepted(self) -> None: async def handle_connection_init_timeout(self) -> None: task = asyncio.current_task() - assert task is not None # for typecheckers + assert task try: delay = self.connection_init_wait_timeout.total_seconds() await asyncio.sleep(delay=delay) @@ -266,13 +264,12 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: operation_name=message.payload.operationName, ) - # create AsyncGenerator returning a single result - async def single_result() -> AsyncIterator[ExecutionResult]: - yield result # type: ignore + operation = Operation(self, message.id, operation_type) # Create task to handle this subscription, reserve the operation ID - operation = Operation(self, message.id, operation_type, start_operation) - operation.task = asyncio.create_task(self.operation_task(operation)) + operation.task = asyncio.create_task( + self.operation_task(result_source, operation) + ) self.operations[message.id] = operation async def operation_task( @@ -302,11 +299,9 @@ async def operation_task( self.operations.pop(operation.id, None) raise finally: - # Clenaup. Remove the operation from the list of active operations - if operation.id in self.operations: - del self.operations[operation.id] - # TODO: Stop collecting background tasks, not necessary. - # Add this task to a list to be reaped later + # add this task to a list to be reaped later + task = asyncio.current_task() + assert task is not None self.completed_tasks.append(task) def forget_id(self, id: str) -> None: @@ -344,35 +339,23 @@ async def reap_completed_tasks(self) -> None: class Operation: """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" - __slots__ = [ - "handler", - "id", - "operation_type", - "start_operation", - "completed", - "task", - ] + __slots__ = ["handler", "id", "operation_type", "completed", "task"] def __init__( self, handler: BaseGraphQLTransportWSHandler, id: str, operation_type: OperationType, - start_operation: Callable[ - [], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]] - ], ) -> None: self.handler = handler self.id = id self.operation_type = operation_type - self.start_operation = start_operation self.completed = False self.task: Optional[asyncio.Task] = None async def send_message(self, message: GraphQLTransportMessage) -> None: - # defensive check, should never happen if self.completed: - return # pragma: no cover + return if isinstance(message, (CompleteMessage, ErrorMessage)): self.completed = True # de-register the operation _before_ sending the final message From c074e095c47d0c56c38f3f60a1d3ff7796d6478e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 13 Oct 2024 19:25:05 +0000 Subject: [PATCH 22/24] Remove tests for long contexts --- tests/websockets/test_graphql_transport_ws.py | 74 +------------------ 1 file changed, 2 insertions(+), 72 deletions(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index b88ede884e..e4cd0ad109 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1071,12 +1071,7 @@ async def test_long_validation_concurrent_query(ws: WebSocketClient): # we expect the second query to arrive first, because the # first query is stuck in validation response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"conditionalFail": "Hey"}} - ).as_dict() - ) + assert_next(response, "sub2", {"conditionalFail": "Hey"}) async def test_long_validation_concurrent_subscription(ws: WebSocketClient): @@ -1104,72 +1099,7 @@ async def test_long_validation_concurrent_subscription(ws: WebSocketClient): # we expect the second query to arrive first, because the # first operation is stuck in validation response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"conditionalFail": "Hey"}} - ).as_dict() - ) - - -async def test_long_custom_context( - ws: WebSocketClient, http_client_class: Type[HttpClient] -): - """ - Test that the websocket is not blocked evaluating the context - """ - if http_client_class in (FastAPIHttpClient, StarliteHttpClient, LitestarHttpClient): - pytest.skip("Client evaluates the context only once per connection") - - counter = 0 - - async def slow_get_context(ctxt): - nonlocal counter - old = counter - counter += 1 - if old == 0: - await asyncio.sleep(0.1) - ctxt["custom_value"] = "slow" - else: - ctxt["custom_value"] = "fast" - return ctxt - - with patch("tests.http.context.get_context_async_inner", slow_get_context): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="query { valueFromContext }"), - ).as_dict() - ) - - await ws.send_json( - SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query="query { valueFromContext }", - ), - ).as_dict() - ) - - # we expect the second query to arrive first, because the - # first operation is stuck getting context - response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"valueFromContext": "fast"}} - ).as_dict() - ) - - response = await ws.receive_json() - if response == CompleteMessage(id="sub2").as_dict(): - response = await ws.receive_json() # ignore the complete message - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"valueFromContext": "slow"}} - ).as_dict() - ) + assert_next(response, "sub2", {"conditionalFail": "Hey"}) async def test_task_error_handler(ws: WebSocketClient): From 373400f1eddf5be24bd3a67a19a633d6ebf941a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 13 Oct 2024 19:28:33 +0000 Subject: [PATCH 23/24] Revert "add an async context getter for tests which is easily patchable." This reverts commit 619800790e7d1003985d7f2e6641f06b22d4581f. --- tests/http/clients/aiohttp.py | 4 ++-- tests/http/clients/asgi.py | 4 ++-- tests/http/clients/channels.py | 4 ++-- tests/http/clients/fastapi.py | 4 ++-- tests/http/clients/litestar.py | 4 ++-- tests/http/context.py | 15 --------------- 6 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index d817910d2c..89b0c718e8 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -39,7 +39,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return await get_context(context) + return get_context(context) async def get_root_value(self, request: web.Request) -> Query: await super().get_root_value(request) # for coverage diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 335ec06aa9..7910e02f73 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -45,7 +45,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return await get_context(context) + return get_context(context) async def process_result( self, request: Request, result: ExecutionResult diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 7e0627572c..bde2364128 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -20,7 +20,7 @@ from strawberry.http.typevars import Context, RootValue from tests.views.schema import Query, schema -from ..context import get_context, get_context_async +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -78,7 +78,7 @@ async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue] async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: context = await super().get_context(request, response) - return await get_context_async(context) + return get_context(context) async def process_result( self, request: ChannelsConsumer, result: Any diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 4da966ce40..b1b80625fa 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .asgi import AsgiWebSocketClient from .base import ( JSON, @@ -39,7 +39,7 @@ async def fastapi_get_context( ws: WebSocket = None, # type: ignore custom_value: str = Depends(custom_context_dependency), ) -> Dict[str, object]: - return await get_context( + return get_context( { "request": request or ws, "background_tasks": background_tasks, diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 661c8fc1ee..b29be52b32 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -16,7 +16,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context_async as get_context +from ..context import get_context from .base import ( JSON, DebuggableGraphQLTransportWSHandler, @@ -30,7 +30,7 @@ async def litestar_get_context(request: Request = None): - return await get_context({"request": request}) + return get_context({"request": request}) async def get_root_value(request: Request = None): diff --git a/tests/http/context.py b/tests/http/context.py index c1ce5dbecf..99985b2434 100644 --- a/tests/http/context.py +++ b/tests/http/context.py @@ -2,21 +2,6 @@ def get_context(context: object) -> Dict[str, object]: - return get_context_inner(context) - - -# a patchable method for unittests -def get_context_inner(context: object) -> Dict[str, object]: assert isinstance(context, dict) - return {**context, "custom_value": "a value from context"} - -# async version for async frameworks -async def get_context_async(context: object) -> Dict[str, object]: - return await get_context_async_inner(context) - - -# a patchable method for unittests -async def get_context_async_inner(context: object) -> Dict[str, object]: - assert isinstance(context, dict) return {**context, "custom_value": "a value from context"} From c4d0b0555cc98ffbb50ea7bd95f40fdcae22b5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 13 Oct 2024 19:39:47 +0000 Subject: [PATCH 24/24] cleanup --- RELEASE.md | 3 +-- tests/http/clients/litestar.py | 4 ++++ tests/views/schema.py | 14 ++++++------- tests/websockets/test_graphql_transport_ws.py | 20 ++++--------------- 4 files changed, 16 insertions(+), 25 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 2820214c74..8f6812b821 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,3 @@ Release type: patch -Operations over `graphql-transport-ws` now create the Context and perform validation on -the worker `Task`, thus not blocking the websocket from accepting messages. +Fix error handling for query operations over graphql-transport-ws \ No newline at end of file diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index b29be52b32..2548dc563c 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -29,6 +29,10 @@ ) +def custom_context_dependency() -> str: + return "Hi!" + + async def litestar_get_context(request: Request = None): return get_context({"request": request}) diff --git a/tests/views/schema.py b/tests/views/schema.py index d8a4e89107..cf50dfe415 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -77,7 +77,7 @@ class DebugInfo: @strawberry.type class Query: @strawberry.field - def greetings(self) -> str: # pragma: no cover + def greetings(self) -> str: return "hello" @strawberry.field @@ -91,13 +91,13 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str @strawberry.field(permission_classes=[AlwaysFailPermission]) def always_fail(self) -> Optional[str]: - return "Hey" # pragma: no cover + return "Hey" @strawberry.field(permission_classes=[ConditionalFailPermission]) def conditional_fail( self, sleep: Optional[float] = None, fail: bool = False ) -> str: - return "Hey" # pragma: no cover + return "Hey" @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: @@ -108,7 +108,7 @@ async def exception(self, message: str) -> str: raise ValueError(message) @strawberry.field - def teapot(self, info: strawberry.Info[Any, None]) -> str: # pragma: no cover + def teapot(self, info: strawberry.Info[Any, None]) -> str: info.context["response"].status_code = 418 return "🫖" @@ -142,7 +142,7 @@ def set_header(self, info: strawberry.Info, name: str) -> str: @strawberry.type class Mutation: @strawberry.mutation - def echo(self, string_to_echo: str) -> str: # pragma: no cover + def echo(self, string_to_echo: str) -> str: return string_to_echo @strawberry.mutation @@ -162,7 +162,7 @@ def read_folder(self, folder: FolderInput) -> List[str]: return list(map(_read_file, folder.files)) @strawberry.mutation - def match_text(self, text_file: Upload, pattern: str) -> str: # pragma: no cover + def match_text(self, text_file: Upload, pattern: str) -> str: text = text_file.read().decode() return pattern if pattern in text else "" @@ -199,7 +199,7 @@ async def exception(self, message: str) -> AsyncGenerator[str, None]: raise ValueError(message) # Without this yield, the method is not recognised as an async generator - yield "Hi" # pragma: no cover + yield "Hi" @strawberry.subscription async def flavors(self) -> AsyncGenerator[Flavor, None]: diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index e4cd0ad109..fb4541978c 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -30,23 +30,8 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSHandler from tests.views.schema import MyExtension, Schema -from ..http.clients.base import WebSocketClient - -try: - from ..http.clients.fastapi import FastAPIHttpClient -except ImportError: # pragma: no cover - FastAPIHttpClient = None -try: - from ..http.clients.starlite import StarliteHttpClient -except ImportError: # pragma: no cover - StarliteHttpClient = None -try: - from ..http.clients.litestar import LitestarHttpClient -except ImportError: # pragma: no cover - LitestarHttpClient = None - if TYPE_CHECKING: - from ..http.clients.base import HttpClient + from ..http.clients.base import HttpClient, WebSocketClient @pytest_asyncio.fixture @@ -913,6 +898,9 @@ async def test_error_handler_for_timeout(http_client: HttpClient): if isinstance(http_client, ChannelsHttpClient): pytest.skip("Can't patch on_init for this client") + if not AsyncMock: + pytest.skip("Don't have AsyncMock") + ws = ws_raw handler = None errorhandler = AsyncMock()