diff --git a/tests/test_asgi_csrf.py b/tests/test_asgi_csrf.py index 3616871..110ceff 100644 --- a/tests/test_asgi_csrf.py +++ b/tests/test_asgi_csrf.py @@ -85,7 +85,9 @@ def csrftoken(): @pytest.mark.asyncio async def test_hello_world_app(): - async with httpx.AsyncClient(app=hello_world_app) as client: + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=hello_world_app) + ) as client: response = await client.get("http://localhost/") assert b'{"hello":"world"}' == response.content @@ -112,7 +114,7 @@ def _get_secret_key(app): @pytest.mark.asyncio async def test_asgi_csrf_sets_cookie(app_csrf): - async with httpx.AsyncClient(app=app_csrf) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app_csrf)) as client: response = await client.get("http://localhost/") assert b'{"hello":"world"}' == response.content assert "csrftoken" in response.cookies @@ -122,7 +124,7 @@ async def test_asgi_csrf_sets_cookie(app_csrf): @pytest.mark.asyncio async def test_asgi_csrf_modifies_existing_vary_header(app_csrf): - async with httpx.AsyncClient(app=app_csrf) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app_csrf)) as client: response = await client.get("http://localhost/?_vary=User-Agent") assert b'{"hello":"world"}' == response.content assert "csrftoken" in response.cookies @@ -132,7 +134,7 @@ async def test_asgi_csrf_modifies_existing_vary_header(app_csrf): @pytest.mark.asyncio async def test_asgi_csrf_sets_no_cookie_or_vary_if_page_has_no_form(app_csrf): - async with httpx.AsyncClient(app=app_csrf) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app_csrf)) as client: response = await client.get("http://localhost/static") assert b'{"hello":"world","static":true}' == response.content assert "csrftoken" not in response.cookies @@ -141,46 +143,36 @@ async def test_asgi_csrf_sets_no_cookie_or_vary_if_page_has_no_form(app_csrf): @pytest.mark.asyncio async def test_vary_header_only_if_page_contains_csrftoken(app_csrf, csrftoken): - async with httpx.AsyncClient(app=app_csrf) as client: - assert ( - "vary" - in ( - await client.get("http://localhost/", cookies={"csrftoken": csrftoken}) - ).headers - ) - assert ( - "vary" - not in ( - await client.get( - "http://localhost/?_no_token=1", cookies={"csrftoken": csrftoken} - ) - ).headers - ) + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_csrf), cookies={"csrftoken": csrftoken} + ) as client: + assert "vary" in (await client.get("http://localhost/")).headers + assert "vary" not in (await client.get("http://localhost/?_no_token=1")).headers @pytest.mark.asyncio async def test_headers_passed_through_correctly(app_csrf): - async with httpx.AsyncClient(app=app_csrf) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app_csrf)) as client: response = await client.get("http://localhost/static") assert "application/json" == response.headers["content-type"] @pytest.mark.asyncio async def test_asgi_csrf_does_not_set_cookie_if_one_sent(app_csrf, csrftoken): - async with httpx.AsyncClient(app=app_csrf) as client: - response = await client.get( - "http://localhost/", cookies={"csrftoken": csrftoken} - ) + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_csrf), cookies={"csrftoken": csrftoken} + ) as client: + response = await client.get("http://localhost/") assert b'{"hello":"world"}' == response.content assert "csrftoken" not in response.cookies @pytest.mark.asyncio async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken): - async with httpx.AsyncClient(app=app_csrf) as client: - response = await client.post( - "http://localhost/", cookies={"csrftoken": csrftoken} - ) + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_csrf), cookies={"csrftoken": csrftoken} + ) as client: + response = await client.post("http://localhost/") assert 403 == response.status_code @@ -190,11 +182,13 @@ async def test_prevents_post_if_cookie_not_sent_in_post( custom_errors, app_csrf, app_csrf_custom_errors, csrftoken ): async with httpx.AsyncClient( - app=app_csrf_custom_errors if custom_errors else app_csrf + transport=httpx.ASGITransport( + app=app_csrf_custom_errors if custom_errors else app_csrf + ), + cookies={"csrftoken": csrftoken}, ) as client: response = await client.post( "http://localhost/", - cookies={"csrftoken": csrftoken}, data={"csrftoken": csrftoken[-1]}, ) assert 403 == response.status_code @@ -207,11 +201,12 @@ async def test_prevents_post_if_cookie_not_sent_in_post( @pytest.mark.asyncio async def test_allows_post_if_cookie_duplicated_in_header(app_csrf, csrftoken): - async with httpx.AsyncClient(app=app_csrf) as client: + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_csrf), cookies={"csrftoken": csrftoken} + ) as client: response = await client.post( "http://localhost/", headers={"x-csrftoken": csrftoken}, - cookies={"csrftoken": csrftoken}, ) assert 200 == response.status_code @@ -219,12 +214,14 @@ async def test_allows_post_if_cookie_duplicated_in_header(app_csrf, csrftoken): @pytest.mark.asyncio async def test_allows_post_if_cookie_duplicated_in_post_data(csrftoken): async with httpx.AsyncClient( - app=asgi_csrf(hello_world_app, signing_secret=SECRET) + transport=httpx.ASGITransport( + app=asgi_csrf(hello_world_app, signing_secret=SECRET) + ), + cookies={"csrftoken": csrftoken}, ) as client: response = await client.post( "http://localhost/", data={"csrftoken": csrftoken, "hello": "world"}, - cookies={"csrftoken": csrftoken}, ) assert 200 == response.status_code assert {"csrftoken": csrftoken, "hello": "world"} == json.loads(response.content) @@ -233,13 +230,15 @@ async def test_allows_post_if_cookie_duplicated_in_post_data(csrftoken): @pytest.mark.asyncio async def test_multipart(csrftoken): async with httpx.AsyncClient( - app=asgi_csrf(hello_world_app, signing_secret=SECRET) + transport=httpx.ASGITransport( + app=asgi_csrf(hello_world_app, signing_secret=SECRET) + ), + cookies={"csrftoken": csrftoken}, ) as client: response = await client.post( "http://localhost/", data={"csrftoken": csrftoken}, files={"csv": ("data.csv", "blah,foo\n1,2", "text/csv")}, - cookies={"csrftoken": csrftoken}, ) assert response.status_code == 200 assert response.json() == {"csrftoken": csrftoken, "csv": "blah,foo\n1,2"} @@ -249,17 +248,19 @@ async def test_multipart(csrftoken): @pytest.mark.parametrize("custom_errors", (False, True)) async def test_multipart_failure_wrong_token(csrftoken, custom_errors): async with httpx.AsyncClient( - app=asgi_csrf( - hello_world_app, - signing_secret=SECRET, - send_csrf_failed=custom_csrf_failed if custom_errors else None, - ) + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, + signing_secret=SECRET, + send_csrf_failed=custom_csrf_failed if custom_errors else None, + ) + ), + cookies={"csrftoken": csrftoken[:-1]}, ) as client: response = await client.post( "http://localhost/", data={"csrftoken": csrftoken}, files={"csv": ("data.csv", "blah,foo\n1,2", "text/csv")}, - cookies={"csrftoken": csrftoken[:-1]}, ) assert response.status_code == 403 assert ( @@ -279,17 +280,19 @@ def __bool__(self): @pytest.mark.parametrize("custom_errors", (False, True)) async def test_multipart_failure_missing_token(csrftoken, custom_errors): async with httpx.AsyncClient( - app=asgi_csrf( - hello_world_app, - signing_secret=SECRET, - send_csrf_failed=custom_csrf_failed if custom_errors else None, - ) + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, + signing_secret=SECRET, + send_csrf_failed=custom_csrf_failed if custom_errors else None, + ) + ), + cookies={"csrftoken": csrftoken}, ) as client: response = await client.post( "http://localhost/", data={"foo": "bar"}, files=TrickEmptyDictionary(), - cookies={"csrftoken": csrftoken}, ) assert response.status_code == 403 assert response.text == ( @@ -303,10 +306,12 @@ async def test_multipart_failure_missing_token(csrftoken, custom_errors): @pytest.mark.parametrize("custom_errors", (False, True)) async def test_multipart_failure_file_comes_before_token(csrftoken, custom_errors): async with httpx.AsyncClient( - app=asgi_csrf( - hello_world_app, - signing_secret=SECRET, - send_csrf_failed=custom_csrf_failed if custom_errors else None, + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, + signing_secret=SECRET, + send_csrf_failed=custom_csrf_failed if custom_errors else None, + ) ) ) as client: request = httpx.Request( @@ -342,12 +347,14 @@ async def test_multipart_failure_file_comes_before_token(csrftoken, custom_error ) async def test_post_with_authorization(authorization, expected_status): async with httpx.AsyncClient( - app=asgi_csrf(hello_world_app, signing_secret=SECRET) + transport=httpx.ASGITransport( + app=asgi_csrf(hello_world_app, signing_secret=SECRET) + ), + cookies={"foo": "bar"}, ) as client: response = await client.post( "http://localhost/", headers={"Authorization": authorization}, - cookies={"foo": "bar"}, ) assert expected_status == response.status_code @@ -366,9 +373,14 @@ async def test_no_cookies_skips_check_unless_path_required( cookies, path, expected_status ): async with httpx.AsyncClient( - app=asgi_csrf(hello_world_app, signing_secret=SECRET, always_protect={"/login"}) + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, signing_secret=SECRET, always_protect={"/login"} + ) + ), + cookies=cookies, ) as client: - response = await client.post("http://localhost{}".format(path), cookies=cookies) + response = await client.post("http://localhost{}".format(path)) assert expected_status == response.status_code @@ -386,13 +398,16 @@ async def test_no_cookies_skips_check_unless_path_required( ) async def test_skip_if_scope(cookies, path, expected_status): async with httpx.AsyncClient( - app=asgi_csrf( - hello_world_app, - signing_secret=SECRET, - skip_if_scope=lambda scope: scope["path"].startswith("/api/"), - ) + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, + signing_secret=SECRET, + skip_if_scope=lambda scope: scope["path"].startswith("/api/"), + ) + ), + cookies=cookies, ) as client: - response = await client.post("http://localhost{}".format(path), cookies=cookies) + response = await client.post("http://localhost{}".format(path)) assert expected_status == response.status_code @@ -400,8 +415,12 @@ async def test_skip_if_scope(cookies, path, expected_status): @pytest.mark.parametrize("always_set_cookie", [True, False]) async def test_always_set_cookie(always_set_cookie): async with httpx.AsyncClient( - app=asgi_csrf( - hello_world_app, signing_secret=SECRET, always_set_cookie=always_set_cookie + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, + signing_secret=SECRET, + always_set_cookie=always_set_cookie, + ) ) ) as client: response = await client.get("http://localhost/static") @@ -415,13 +434,18 @@ async def test_always_set_cookie(always_set_cookie): @pytest.mark.asyncio @pytest.mark.parametrize("send_csrftoken_cookie", [True, False]) async def test_always_set_cookie_unless_cookie_is_set(send_csrftoken_cookie, csrftoken): + cookies = {} + if send_csrftoken_cookie: + cookies["csrftoken"] = csrftoken async with httpx.AsyncClient( - app=asgi_csrf(hello_world_app, signing_secret=SECRET, always_set_cookie=True) + transport=httpx.ASGITransport( + app=asgi_csrf( + hello_world_app, signing_secret=SECRET, always_set_cookie=True + ) + ), + cookies=cookies, ) as client: - cookies = {} - if send_csrftoken_cookie: - cookies["csrftoken"] = csrftoken - response = await client.get("http://localhost/static", cookies=cookies) + response = await client.get("http://localhost/static") assert 200 == response.status_code if send_csrftoken_cookie: assert "csrftoken" not in response.cookies @@ -433,11 +457,13 @@ async def test_always_set_cookie_unless_cookie_is_set(send_csrftoken_cookie, csr async def test_asgi_lifespan(): app = asgi_csrf(hello_world_app, signing_secret=SECRET) async with LifespanManager(app): - async with httpx.AsyncClient(app=app) as client: + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + cookies={"foo": "bar"}, + ) as client: response = await client.post( "http://localhost/", headers={"Authorization": "Bearer xxx"}, - cookies={"foo": "bar"}, ) assert 200 == response.status_code @@ -452,7 +478,9 @@ async def test_cookie_name(cookie_name): hello_world_app, signing_secret="secret", cookie_name=cookie_name ) transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") assert cookie_name in response.cookies @@ -464,7 +492,9 @@ async def test_cookie_path(cookie_path): hello_world_app, signing_secret="secret", cookie_path=cookie_path ) transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") assert f"Path={cookie_path}" in response.headers["set-cookie"] @@ -476,7 +506,9 @@ async def test_cookie_domain(cookie_domain): hello_world_app, signing_secret="secret", cookie_domain=cookie_domain ) transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") if cookie_domain: assert f"Domain={cookie_domain}" in response.headers["set-cookie"] @@ -491,7 +523,9 @@ async def test_cookie_secure(cookie_secure): hello_world_app, signing_secret="secret", cookie_secure=cookie_secure ) transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") if cookie_secure: assert "Secure" in response.headers["set-cookie"] @@ -506,7 +540,9 @@ async def test_cookie_samesite(cookie_samesite): hello_world_app, signing_secret="secret", cookie_samesite=cookie_samesite ) transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") assert f"SameSite={cookie_samesite}" in response.headers["set-cookie"] @@ -515,7 +551,9 @@ async def test_cookie_samesite(cookie_samesite): async def test_default_cookie_options(): wrapped_app = asgi_csrf(hello_world_app, signing_secret="secret") transport = httpx.ASGITransport(app=wrapped_app) - async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: response = await client.get("http://testserver/") set_cookie = response.headers["set-cookie"] assert "csrftoken" in set_cookie