Skip to content

Commit

Permalink
Add unit test to verify that subscription AsyncGen is closed
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Apr 20, 2023
1 parent 67dfa8c commit b7255b3
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/aiohttp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ async def get_context(self) -> object:
context["ws"] = self._ws
context["tasks"] = self.tasks
context["connectionInitTimeoutTask"] = self.connection_init_timeout_task
context["handler"] = self
return context


Expand All @@ -19,6 +20,7 @@ async def get_context(self) -> object:
context["ws"] = self._ws
context["tasks"] = self.tasks
context["connectionInitTimeoutTask"] = None
context["handler"] = self
return context


Expand Down
2 changes: 2 additions & 0 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ async def get_context(self) -> object:
context["ws"] = self._ws
context["tasks"] = self.tasks
context["connectionInitTimeoutTask"] = self.connection_init_timeout_task
context["handler"] = self
return context


Expand All @@ -257,4 +258,5 @@ async def get_context(self) -> object:
context["ws"] = self._ws
context["tasks"] = self.tasks
context["connectionInitTimeoutTask"] = None
context["handler"] = self
return context
1 change: 1 addition & 0 deletions tests/http/clients/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class DebuggableGraphQLTransportWSConsumer(GraphQLWSConsumer):
async def get_context(self, *args, **kwargs) -> object:
context = await super().get_context(*args, **kwargs)
context.tasks = self._handler.tasks
context.handler = self._handler
context.connectionInitTimeoutTask = getattr(
self._handler, "connection_init_timeout_task", None
)
Expand Down
16 changes: 16 additions & 0 deletions tests/views/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def set_header(self, info: Info[Any, Any], name: str) -> str:

return name

@strawberry.field
def finalizer_state(self, info: Info[Any, Any]) -> bool:
return getattr(info.context["handler"], "_debug_finalizer", False)


@strawberry.type
class Mutation:
Expand Down Expand Up @@ -224,6 +228,18 @@ async def connection_params(
) -> AsyncGenerator[str, None]:
yield info.context["connection_params"]["strawberry"]

@strawberry.subscription
async def finalizer(self, info: Info[Any, Any]) -> AsyncGenerator[str, None]:
info.context["handler"]._debug_finalizer = True
try:
i = 0
while True:
yield f"finalizer {i}"
i += 1
await asyncio.sleep(0.01)
finally:
info.context["handler"]._debug_finalizer = False


schema = strawberry.Schema(
query=Query,
Expand Down
72 changes: 72 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import time
from datetime import timedelta
from typing import AsyncGenerator, Type

Expand Down Expand Up @@ -746,3 +747,74 @@ async def test_rejects_connection_params_not_unset(ws_raw: WebSocketClient):
assert ws.closed
assert ws.close_code == 4400
ws.assert_reason("Invalid connection init payload")


async def test_subscription_finializer_called(ws: WebSocketClient):
# Test that a subscription is proptly finalized when the client interrupts the
# subscription
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(query="subscription { finalizer }"),
).as_dict()
)

response = await ws.receive_json()
assert (
response
== NextMessage(
id="sub1", payload={"data": {"finalizer": "finalizer 0"}}
).as_dict()
)

# check that the context is live, finalize hasn't been called yet.
# finalizer_state is set True to when subscription is running, then back
# to False when the subscription is finalized
await ws.send_json(
SubscribeMessage(
id="sub2",
payload=SubscribeMessagePayload(query="query { finalizerState }"),
).as_dict()
)

# wait for the sub2 message
while True:
response = await ws.receive_json()
assert response["type"] in ["next", "complete"]
if response["type"] != "next" or response["id"] != "sub2":
continue
assert response["payload"]["data"]["finalizerState"] is True
break

# cancel the subscription
await ws.send_json(CompleteMessage(id="sub1").as_dict())

# wait until context is dead.
# We don't know exactly how many packets will arrive or how long it will take.
# Need manual timeout because async timeout doesn't work on async integrations
async def wait_for_finalize():
counter = 0
start = time.time()
while True:
now = time.time()
if now - start > 1:
raise TimeoutError("Timeout waiting for finalizer to be called")
counter += 1
id = f"check{counter}"
await ws.send_json(
SubscribeMessage(
id=id,
payload=SubscribeMessagePayload(query="query { finalizerState }"),
).as_dict()
)
# wait for the response showing that finalizer state is back to false
while True:
response = await ws.receive_json()
assert response["type"] in ["next", "complete"]
if response["type"] != "next" or response["id"] != id:
continue
if response["payload"]["data"]["finalizerState"] is False:
return
break

await asyncio.wait_for(wait_for_finalize(), timeout=1)

0 comments on commit b7255b3

Please sign in to comment.