From 4dd8c5de79df2ee56959dba7bc20bbcdbb4d833d Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 10:31:10 -0500 Subject: [PATCH] Refactor anyio.sleep into an event --- starlette/testclient.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index f4a960a3f..4de2e46a2 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -169,17 +169,16 @@ def send( request_complete = False response_started = False - response_complete = False + response_complete: anyio.Event raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] template = None context = None async def receive() -> Message: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: - while not response_complete: - await anyio.sleep(0.0001) + await response_complete.wait() return {"type": "http.disconnect"} body = request.body @@ -203,7 +202,7 @@ async def receive() -> Message: return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( @@ -225,7 +224,7 @@ async def send(message: Message) -> None: response_started ), 'Received "http.response.body" without "http.response.start".' assert ( - not response_complete + not response_complete.is_set() ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) @@ -233,13 +232,14 @@ async def send(message: Message) -> None: raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) - response_complete = True + response_complete.set() elif message["type"] == "http.response.template": template = message["template"] context = message["context"] try: with anyio.start_blocking_portal(**self.async_backend) as portal: + response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: