diff --git a/asgi_csrf.py b/asgi_csrf.py index f49e415..cbe8d9e 100644 --- a/asgi_csrf.py +++ b/asgi_csrf.py @@ -101,8 +101,7 @@ async def _parse_form_urlencoded(receive): more_body = message.get("more_body", False) async def replay_receive(): - for message in messages: - yield message + return messages.pop() return dict(parse_qsl(body.decode("utf-8"))), replay_receive diff --git a/setup.py b/setup.py index a099f8f..30f8967 100644 --- a/setup.py +++ b/setup.py @@ -22,5 +22,5 @@ def get_long_description(): license="Apache License, Version 2.0", version=VERSION, py_modules=["asgi_csrf"], - extras_require={"test": ["pytest", "pytest-asyncio", "httpx",]}, + extras_require={"test": ["pytest", "pytest-asyncio", "httpx", "starlette"]}, ) diff --git a/test_asgi_csrf.py b/test_asgi_csrf.py index ac1b66f..91646a7 100644 --- a/test_asgi_csrf.py +++ b/test_asgi_csrf.py @@ -1,3 +1,6 @@ +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.routing import Route from asgi_csrf import asgi_csrf import httpx import pytest @@ -5,6 +8,15 @@ CSRF_TOKEN = "9izX9q37XP9knNNQ" +async def hello_world(request): + if request.method == "POST": + return JSONResponse(dict(await request.form())) + return JSONResponse({"hello": "world"}) + + +hello_world_app = Starlette(routes=[Route("/", hello_world),]) + + async def hello_world_app(scope, receive, send): assert scope["type"] == "http" await send(