Skip to content

Commit

Permalink
Fix ASGIMiddleware Receive (#59)
Browse files Browse the repository at this point in the history
* -

* -

* add baize.asgi test
  • Loading branch information
abersheeran authored Jun 26, 2024
1 parent aaf1733 commit dd66cda
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 34 deletions.
61 changes: 39 additions & 22 deletions a2wsgi/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,28 @@ def __init__(
self.wait_time = wait_time

self.sync_event = SyncEvent()
self.async_event = AsyncEvent(loop)
self.async_lock: asyncio.Lock
self.sync_event_set_lock: asyncio.Lock

self.receive_event = AsyncEvent(loop)
self.send_event = AsyncEvent(loop)

def _init_async_lock():
self.async_lock = asyncio.Lock()
self.sync_event_set_lock = asyncio.Lock()

loop.call_soon_threadsafe(_init_async_lock)

self.asgi_done = threading.Event()
self.wsgi_should_stop: bool = False

async def asgi_receive(self) -> ReceiveEvent:
async with self.async_lock:
self.sync_event.set({"type": "receive"})
return await self.async_event.wait()
await self.sync_event_set_lock.acquire()
self.sync_event.set({"type": "receive"})
return await self.receive_event.wait()

async def asgi_send(self, message: SendEvent) -> None:
async with self.async_lock:
self.sync_event.set(message)
await self.async_event.wait()
await self.sync_event_set_lock.acquire()
self.sync_event.set(message)
await self.send_event.wait()

def asgi_done_callback(self, future: asyncio.Future) -> None:
try:
Expand Down Expand Up @@ -209,13 +211,16 @@ def __call__(
read_count: int = 0
body = environ["wsgi.input"] or BytesIO()
content_length = int(environ.get("CONTENT_LENGTH", None) or 0)
receive_eof = False
body_sent = False

asgi_task = self.start_asgi_app(environ)
# activate loop
self.loop.call_soon_threadsafe(lambda: None)

while True:
message = self.sync_event.wait()
self.loop.call_soon_threadsafe(self.sync_event_set_lock.release)
message_type = message["type"]

if message_type == "http.response.start":
Expand All @@ -230,13 +235,21 @@ def __call__(
],
None,
)
self.send_event.set(None)
elif message_type == "http.response.body":
yield message.get("body", b"")
body_sent = True
self.wsgi_should_stop = not message.get("more_body", False)
self.send_event.set(None)
elif message_type == "http.response.disconnect":
self.wsgi_should_stop = True
self.send_event.set(None)
# ASGI application error
elif message_type == "a2wsgi.error":
if body_sent:
raise message["exception"][1].with_traceback(
message["exception"][2]
)
start_response(
"500 Internal Server Error",
[
Expand All @@ -248,23 +261,27 @@ def __call__(
yield b"Server got itself in trouble"
self.wsgi_should_stop = True
elif message_type == "receive":
pass
else:
raise RuntimeError(f"Unknown message type: {message_type}")

if message_type == "receive":
read_size = min(65536, content_length - read_count)
data: bytes = body.read(read_size)
read_count += len(data)
more_body = read_count < content_length
self.async_event.set(
{"type": "http.request", "body": data, "more_body": more_body}
)
if read_size == 0: # No more body, so don't read anymore
if not receive_eof:
self.receive_event.set(
{"type": "http.request", "body": b"", "more_body": False}
)
receive_eof = True
else:
pass # let `await receive()` wait
else:
data: bytes = body.read(read_size)
read_count += len(data)
more_body = read_count < content_length
self.receive_event.set(
{"type": "http.request", "body": data, "more_body": more_body}
)
else:
self.async_event.set(None)
raise RuntimeError(f"Unknown message type: {message_type}")

if self.wsgi_should_stop:
self.async_event.set_nowait()
self.receive_event.set({"type": "http.disconnect"})
break

if asgi_task.done():
Expand Down
40 changes: 38 additions & 2 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
[project]
authors = [{ name = "abersheeran", email = "[email protected]" }]
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"typing_extensions; python_version<'3.11'"
]
dependencies = ["typing_extensions; python_version<'3.11'"]
description = "Convert WSGI app to ASGI app or ASGI app to WSGI app."
license = { text = "Apache-2.0" }
name = "a2wsgi"
Expand All @@ -20,13 +18,17 @@ dev = [
"asgiref<4.0.0,>=3.2.7",
"black",
"flake8",
"pytest<8.0.0,>=7.0.1",
"pytest-cov<4.0.0,>=3.0.0",
"pytest-asyncio<1.0.0,>=0.11.0",
"mypy",
"httpx<1.0.0,>=0.22.0",
]
benchmark = ["uvicorn>=0.16.0", "asgiref>=3.4.1"]
test = [
"pytest<8.0.0,>=7.0.1",
"pytest-cov<4.0.0,>=3.0.0",
"pytest-asyncio<1.0.0,>=0.11.0",
"starlette>=0.37.2",
"baize>=0.20.8",
]

[tool.pdm.build]
includes = ["a2wsgi"]
Expand Down
54 changes: 50 additions & 4 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ async def hello_world(scope, receive, send):
"status": 200,
"headers": [
[b"content-type", b"text/plain"],
[b"content-length", b"13"],
],
}
)
await send(
{"type": "http.response.body", "body": b"Hello, world!", "more_body": True}
)
await send({"type": "http.response.disconnect"})
await send({"type": "http.response.body", "body": b"Hello, world!"})


async def echo_body(scope, receive, send):
Expand Down Expand Up @@ -194,3 +192,51 @@ def test_http_content_headers():
counter = Counter(scope["headers"])
assert counter[(b"content-type", content_type.encode())] == 1
assert counter[(b"content-length", content_length.encode())] == 1


def test_starlette_stream_response():
from starlette.responses import StreamingResponse

app = ASGIMiddleware(StreamingResponse(content=map(str, range(10))))
with httpx.Client(
transport=httpx.WSGITransport(app=app), base_url="http://testserver:80"
) as client:
response = client.get("/")
assert response.status_code == 200
assert response.text == "0123456789"


def test_starlette_base_http_middleware():
from starlette.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

class Middleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
response = await call_next(request)
response.headers["x-middleware"] = "true"
return response

app = ASGIMiddleware(Middleware(JSONResponse({"hello": "world"})))
with httpx.Client(
transport=httpx.WSGITransport(app=app), base_url="http://testserver:80"
) as client:
response = client.get("/")
assert response.status_code == 200
assert response.text == '{"hello":"world"}'
assert response.headers["x-middleware"] == "true"


def test_baize_stream_response():
from baize.asgi import StreamResponse

async def stream():
for i in range(10):
yield str(i).encode()

app = ASGIMiddleware(StreamResponse(stream()))
with httpx.Client(
transport=httpx.WSGITransport(app=app), base_url="http://testserver:80"
) as client:
response = client.get("/")
assert response.status_code == 200
assert response.text == "0123456789"

0 comments on commit dd66cda

Please sign in to comment.