From 16015df4bd27cdeeef9ef9e9a59de6f94e82b25a Mon Sep 17 00:00:00 2001 From: Nir <88795475+nrbnlulu@users.noreply.github.com> Date: Tue, 21 Feb 2023 14:41:52 +0200 Subject: [PATCH] fix mypy. --- .../protocols/graphql_transport_ws/handlers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 10b9dcfb04..687d07c524 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -3,7 +3,7 @@ import asyncio from abc import ABC, abstractmethod from contextlib import suppress -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional from graphql import GraphQLError, GraphQLSyntaxError, parse from graphql.error.graphql_error import format_error as format_graphql_error @@ -29,9 +29,11 @@ from datetime import timedelta from strawberry.schema import Schema + from strawberry.schema.subscribe import Subscription from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) + from strawberry.types.execution import ExecutionResult class BaseGraphQLTransportWSHandler(ABC): @@ -47,7 +49,7 @@ def __init__( self.connection_init_timeout_task: Optional[asyncio.Task] = None self.connection_init_received = False self.connection_acknowledged = False - self.subscriptions: Dict[str, AsyncGenerator] = {} + self.subscriptions: Dict[str, Subscription] = {} self.tasks: Dict[str, asyncio.Task] = {} self.completed_tasks: List[asyncio.Task] = [] self.connection_params: Optional[Dict[str, Any]] = None @@ -220,7 +222,7 @@ async def get_result_source(): ) async def operation_task( - self, result_source: AsyncGenerator, operation_id: str + self, result_source: AsyncIterator[ExecutionResult], operation_id: str ) -> None: """ Operation task top level method. Cleans up and de-registers the operation @@ -252,7 +254,7 @@ async def operation_task( async def handle_async_results( self, - result_source: AsyncGenerator, + result_source: AsyncIterator[ExecutionResult], operation_id: str, ) -> None: try: