Skip to content

Commit

Permalink
Refactor, run tests on trio
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jun 13, 2020
1 parent 47384a6 commit 927f471
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 138 deletions.
280 changes: 151 additions & 129 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
import contextlib
from typing import (
TYPE_CHECKING,
Union,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
AsyncIterator,
)

import httpcore
import sniffio

from .._content_streams import AsyncIteratorStream, ByteStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
import asyncio
import trio

Event = typing.Union[asyncio.Event, trio.Event]
Event = Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
Expand All @@ -25,19 +35,18 @@ def create_event() -> "Event":
return asyncio.Event()


async def create_background_task(async_fn: typing.Callable) -> typing.Callable:
async def create_background_task(
async_fn: Callable[[], Awaitable[None]]
) -> Callable[[], Awaitable[None]]:
if sniffio.current_async_library() == "trio":
import trio

nursery_manager = trio.open_nursery()
nursery = await nursery_manager.__aenter__()
nursery.start_soon(async_fn)

async def aclose(exc: Exception = None) -> None:
if exc is not None:
await nursery_manager.__aexit__(type(exc), exc, exc.__traceback__)
else:
await nursery_manager.__aexit__(None, None, None)
async def aclose() -> None:
await nursery_manager.__aexit__(None, None, None)

return aclose

Expand All @@ -47,52 +56,63 @@ async def aclose(exc: Exception = None) -> None:
loop = asyncio.get_event_loop()
task = loop.create_task(async_fn())

async def aclose(exc: Exception = None) -> None:
if not task.done():
task.cancel()
async def aclose() -> None:
task.cancel()
# Task must be awaited in all cases to avoid debug warnings.
with contextlib.suppress(asyncio.CancelledError):
await task

return aclose


def create_channel(
capacity: int,
) -> typing.Tuple[
typing.Callable[[], typing.Awaitable[bytes]],
typing.Callable[[bytes], typing.Awaitable[None]],
) -> Tuple[
Callable[[bytes], Awaitable[None]],
Callable[[], Awaitable[None]],
Callable[[], AsyncIterator[bytes]],
]:
"""
Create an in-memory channel to pass data chunks between tasks.
* `produce()`: send data through the channel, blocking if necessary.
* `consume()`: iterate over data in the channel.
* `aclose_produce()`: mark that no more data will be produced, causing
`consume()` to flush remaining data chunks then stop.
"""
if sniffio.current_async_library() == "trio":
import trio

send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
return receive_channel.receive, send_channel.send

async def consume() -> AsyncIterator[bytes]:
async for chunk in receive_channel:
yield chunk

return send_channel.send, send_channel.aclose, consume

else:
import asyncio

queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
return queue.get, queue.put


async def run_until_first_complete(*async_fns: typing.Callable) -> None:
if sniffio.current_async_library() == "trio":
import trio
produce_closed = False

async with trio.open_nursery() as nursery:
async def produce(chunk: bytes) -> None:
assert not produce_closed
await queue.put(chunk)

async def run(async_fn: typing.Callable) -> None:
await async_fn()
nursery.cancel_scope.cancel()
async def aclose_produce() -> None:
nonlocal produce_closed
await queue.put(b"") # Make sure (*) doesn't block forever.
produce_closed = True

for async_fn in async_fns:
nursery.start_soon(run, async_fn)

else:
import asyncio
async def consume() -> AsyncIterator[bytes]:
while True:
if produce_closed and queue.empty():
break
yield await queue.get() # (*)

coros = [async_fn() for async_fn in async_fns]
done, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
return produce, aclose_produce, consume


class ASGITransport(httpcore.AsyncHTTPTransport):
Expand Down Expand Up @@ -149,6 +169,7 @@ async def request(
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
scheme, host, port, full_path = url
headers = [] if headers is None else headers
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
Expand All @@ -163,107 +184,108 @@ async def request(
"client": self.client,
"root_path": self.root_path,
}
status_code = None
response_headers = None
consume_response_body_chunk, produce_response_body_chunk = create_channel(1)
request_complete = False
response_started = create_event()
response_complete = create_event()
app_crashed = create_event()
app_exception: typing.Optional[Exception] = None

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started.is_set()

status_code = message["status"]
response_headers = message.get("headers", [])
response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
await produce_response_body_chunk(body)

if not more_body:
response_complete.set()

async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send)
except Exception as exc:
app_exception = exc
app_crashed.set()

aclose_app = await create_background_task(run_app)

await run_until_first_complete(app_crashed.wait, response_started.wait)

if app_crashed.is_set():
assert app_exception is not None
await aclose_app(app_exception)
if self.raise_app_exceptions or not response_started.is_set():
raise app_exception

assert response_started.is_set()
assert status_code is not None
assert response_headers is not None

async def aiter_response_body_chunks() -> typing.AsyncIterator[bytes]:
chunk = b""

async def consume_chunk() -> None:
nonlocal chunk
chunk = await consume_response_body_chunk()

while True:
await run_until_first_complete(
app_crashed.wait, consume_chunk, response_complete.wait
)
responder = ASGIResponder(
self.app,
raise_app_exceptions=self.raise_app_exceptions,
scope=scope,
request_body=stream.__aiter__(),
)

if app_crashed.is_set():
assert app_exception is not None
if self.raise_app_exceptions:
raise app_exception
else:
break
return await responder()

yield chunk

if response_complete.is_set():
break
class ASGIResponder:
def __init__(
self,
app: Callable,
raise_app_exceptions: bool,
scope: dict,
request_body: AsyncIterator[bytes],
) -> None:
self._app = app
self._raise_app_exceptions = raise_app_exceptions

# Request.
self._scope = scope
self._request_complete = False
self._request_body = request_body

# Response.
self._response_headers: Optional[List[Tuple[bytes, bytes]]] = None
self._status_code: Optional[int] = None
self._response_started_or_app_crashed = create_event()
self._produce_body, self._aclose_body, self._consume_body = create_channel(1)
self._response_complete = create_event()

# Error handling.
self._app_exception: Optional[Exception] = None

async def _receive(self) -> dict:
if self._request_complete:
await self._response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await self._request_body.__anext__()
except StopAsyncIteration:
self._request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}

return {"type": "http.request", "body": body, "more_body": True}

async def _send(self, message: dict) -> None:
if message["type"] == "http.response.start":
# App is sending the response headers.
assert not self._response_started_or_app_crashed.is_set()
self._status_code = message["status"]
self._response_headers = message.get("headers", [])
self._response_started_or_app_crashed.set()

elif message["type"] == "http.response.body":
# App is sending a chunk of the response body.
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and self._scope["method"] != b"HEAD":
await self._produce_body(body)

if not more_body:
await self._aclose_body()
self._response_complete.set()

async def _run_app(self) -> None:
try:
await self._app(self._scope, self._receive, self._send)
except Exception as exc:
self._app_exception = exc
self._response_started_or_app_crashed.set()
await self._aclose_body() # Stop response body consumer once flushed (*).

async def _aiter_response_body(self) -> AsyncIterator[bytes]:
async for chunk in self._consume_body(): # (*)
yield chunk

if self._app_exception is not None and self._raise_app_exceptions:
raise self._app_exception

async def __call__(
self,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
aclose = await create_background_task(self._run_app)

async def aclose() -> None:
await aclose_app(app_exception)
await self._response_started_or_app_crashed.wait()

stream = AsyncIteratorStream(aiter_response_body_chunks(), close_func=aclose)
if self._app_exception is not None:
await aclose()
if self._raise_app_exceptions:
raise self._app_exception

return (b"HTTP/1.1", status_code, b"", response_headers, stream)
assert self._status_code is not None
assert self._response_headers is not None
stream = AsyncIteratorStream(self._aiter_response_body(), close_func=aclose)
return (b"HTTP/1.1", self._status_code, b"", self._response_headers, stream)


class ASGIDispatch(ASGITransport):
Expand Down
Loading

0 comments on commit 927f471

Please sign in to comment.