From a4b4cc5784e739492048aac1579b8a26fdcbac30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 01:06:35 +0000 Subject: [PATCH] Use async schema extensions only for integrations which support it --- tests/http/clients/aiohttp.py | 3 +- tests/http/clients/asgi.py | 3 +- tests/http/clients/channels.py | 2 +- tests/http/clients/fastapi.py | 3 +- tests/http/clients/starlite.py | 3 +- tests/views/schema.py | 46 ++++++++++++++++++- tests/websockets/test_graphql_transport_ws.py | 6 +-- 7 files changed, 56 insertions(+), 10 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 454b7b26b5..f7d42a2f56 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -14,7 +14,8 @@ from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 8207cd4a05..0930303d1a 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -15,7 +15,8 @@ from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 54925dbde1..1e354fde18 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -8,7 +8,7 @@ from channels.testing import WebsocketCommunicator from strawberry.channels import GraphQLWSConsumer -from tests.views.schema import schema +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index bf5d33bf7b..c488263563 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -14,7 +14,8 @@ from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .asgi import AsgiWebSocketClient diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 0ba88dc6ff..d56d4e8a17 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -14,7 +14,8 @@ from strawberry.starlite import make_graphql_controller from strawberry.starlite.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/views/schema.py b/tests/views/schema.py index 1d27285037..2713d05f18 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -39,6 +39,41 @@ class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} + def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + self.resolve_called() + return _next(root, info, *args, **kwargs) + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + def on_operation(self): + self.lifecycle_called("operation", "before") + yield + self.lifecycle_called("operation", "after") + + def on_validate(self): + self.lifecycle_called("validate", "before") + yield + self.lifecycle_called("validate", "after") + + def on_parse(self): + self.lifecycle_called("parse", "before") + yield + self.lifecycle_called("parse", "after") + + def on_execute(self): + self.lifecycle_called("execute", "before") + yield + self.lifecycle_called("execute", "after") + + +class MyAsyncExtension(SchemaExtension): + def get_results(self) -> Dict[str, str]: + return {"example": "example"} + async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): self.resolve_called() result = _next(root, info, *args, **kwargs) @@ -62,12 +97,12 @@ def on_validate(self): yield self.lifecycle_called("validate", "after") - async def on_parse(self): + def on_parse(self): self.lifecycle_called("parse", "before") yield self.lifecycle_called("parse", "after") - def on_execute(self): + async def on_execute(self): self.lifecycle_called("execute", "before") yield self.lifecycle_called("execute", "after") @@ -301,3 +336,10 @@ async def conditional_fail( subscription=Subscription, extensions=[MyExtension], ) + +async_schema = strawberry.Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, + extensions=[MyAsyncExtension], +) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index b578bbe6fd..6ca2a0d1df 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -32,7 +32,7 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from ..http.clients import HttpClient, WebSocketClient -from ..views.schema import MyExtension +from ..views.schema import MyAsyncExtension @pytest_asyncio.fixture @@ -856,8 +856,8 @@ async def test_extensions(ws: WebSocketClient): resolve_called = Mock() lifecycle_called = Mock() - with patch.object(MyExtension, "resolve_called", resolve_called): - with patch.object(MyExtension, "lifecycle_called", lifecycle_called): + with patch.object(MyAsyncExtension, "resolve_called", resolve_called): + with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called): await ws.send_json( SubscribeMessage( id="sub1",