Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support streaming response in asgi transport with asyncio #994

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, cast

import httpcore
import sniffio

from .._content_streams import ByteStream
from .._content_streams import AsyncIteratorStream, ByteStream, ContentStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -96,17 +96,17 @@ async def request(
status_code = None
response_headers = None
body_parts = []
body_parts_event = create_event()
request_complete = False
response_started = False
response_started = create_event()
response_complete = create_event()

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()

async def receive() -> dict:
nonlocal request_complete, response_complete
nonlocal request_complete

if request_complete:
await response_complete.wait()
Expand All @@ -120,15 +120,14 @@ async def receive() -> dict:
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started
assert not response_started.is_set()

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
Expand All @@ -137,23 +136,77 @@ async def send(message: dict) -> None:

if body and method != b"HEAD":
body_parts.append(body)
body_parts_event.set()

if not more_body:
response_complete.set()

try:
await self.app(scope, receive, send)
except Exception:
def handle_exception(ex: Exception) -> None:
if self.raise_app_exceptions or not response_complete:
raise
raise ex from None

response_stream: ContentStream
if sniffio.current_async_library() == "asyncio":
import asyncio

# Tasks need to be created to run the coroutines in the background
loop = asyncio.get_event_loop()
app_task = loop.create_task(self.app(scope, receive, send))
response_task = loop.create_task(response_started.wait())
done, pending = await asyncio.wait(
[app_task, response_task], return_when=asyncio.FIRST_COMPLETED
)

if response_task in done:

async def response_generator() -> typing.AsyncIterator[bytes]:
while True:
if body_parts:
# Body parts immediately available, yield and continue
yield body_parts.pop(0)
continue
# Wait for either more body parts or request to finish
body_parts_task = loop.create_task(body_parts_event.wait())
done, pending = await asyncio.wait(
[app_task, body_parts_task],
return_when=asyncio.FIRST_COMPLETED,
)
if app_task in done and body_parts_task in pending:
# Application finished and no more body parts available
body_parts_task.cancel()
break
cast(asyncio.Event, body_parts_event).clear()
try:
# Make sure the application task is joined before finishing
await app_task
except Exception as ex:
handle_exception(ex)

assert response_started.is_set()
response_stream = AsyncIteratorStream(response_generator())
else:
# Application finished in the middle of making the request
response_task.cancel()
try:
await app_task
except Exception as ex:
handle_exception(ex)

assert response_complete.is_set()
response_stream = ByteStream(b"".join(body_parts))
else:
try:
await self.app(scope, receive, send)
except Exception as ex:
handle_exception(ex)

assert response_complete.is_set()
response_stream = ByteStream(b"".join(body_parts))

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = ByteStream(b"".join(body_parts))

return (b"HTTP/1.1", status_code, b"", response_headers, stream)
return (b"HTTP/1.1", status_code, b"", response_headers, response_stream)


class ASGIDispatch(ASGITransport):
Expand Down