diff --git a/asgi_csrf.py b/asgi_csrf.py index 16a6f66..dfd9e71 100644 --- a/asgi_csrf.py +++ b/asgi_csrf.py @@ -1,14 +1,22 @@ from http.cookies import SimpleCookie -from enum import Enum -import fnmatch +from collections.abc import Awaitable, Callable, Container, Mapping, MutableMapping +from enum import IntEnum from functools import wraps from multipart import FormParser import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from urllib.parse import parse_qsl from itsdangerous.url_safe import URLSafeSerializer from itsdangerous import BadSignature import secrets +if TYPE_CHECKING: + Scope = MutableMapping[str, Any] + Message = MutableMapping[str, Any] + Receive = Callable[[], Awaitable[Message]] + Send = Callable[[Message], Awaitable[None]] + ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] + DEFAULT_COOKIE_NAME = "csrftoken" DEFAULT_COOKIE_PATH = "/" DEFAULT_COOKIE_DOMAIN = None @@ -21,7 +29,8 @@ ENV_SECRET = "ASGI_CSRF_SECRET" -class Errors(Enum): + +class Errors(IntEnum): FORM_URLENCODED_MISMATCH = 1 MULTIPART_MISMATCH = 2 FILE_BEFORE_TOKEN = 3 @@ -37,20 +46,20 @@ class Errors(Enum): def asgi_csrf_decorator( - cookie_name=DEFAULT_COOKIE_NAME, - http_header=DEFAULT_HTTP_HEADER, - form_input=DEFAULT_FORM_INPUT, - signing_secret=None, - signing_namespace=DEFAULT_SIGNING_NAMESPACE, - always_protect=None, - always_set_cookie=False, - skip_if_scope=None, - cookie_path=DEFAULT_COOKIE_PATH, - cookie_domain=DEFAULT_COOKIE_DOMAIN, - cookie_secure=DEFAULT_COOKIE_SECURE, - cookie_samesite=DEFAULT_COOKIE_SAMESITE, - send_csrf_failed=None, -): + cookie_name: str = DEFAULT_COOKIE_NAME, + http_header: str = DEFAULT_HTTP_HEADER, + form_input: str = DEFAULT_FORM_INPUT, + signing_secret: Optional[str] = None, + signing_namespace: str = DEFAULT_SIGNING_NAMESPACE, + always_protect: Optional[Container[str]] = None, + always_set_cookie: bool = False, + skip_if_scope: Optional[Callable[["Scope"], bool]] = None, + cookie_path: str = DEFAULT_COOKIE_PATH, + cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN, + cookie_secure: bool = DEFAULT_COOKIE_SECURE, + cookie_samesite: str = DEFAULT_COOKIE_SAMESITE, + send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None, +) -> Callable[["ASGIApp"], "ASGIApp"]: send_csrf_failed = send_csrf_failed or default_send_csrf_failed if signing_secret is None: signing_secret = os.environ.get(ENV_SECRET, None) @@ -58,9 +67,9 @@ def asgi_csrf_decorator( signing_secret = make_secret(128) signer = URLSafeSerializer(signing_secret) - def _asgi_csrf_decorator(app): + def _asgi_csrf_decorator(app: "ASGIApp") -> "ASGIApp": @wraps(app) - async def app_wrapped_with_csrf(scope, receive, send): + async def app_wrapped_with_csrf(scope: "Scope", receive: "Receive", send: "Send") -> None: if scope["type"] != "http": await app(scope, receive, send) return @@ -84,7 +93,7 @@ async def app_wrapped_with_csrf(scope, receive, send): if not has_csrftoken_cookie: csrftoken = signer.dumps(make_secret(16), signing_namespace) - def get_csrftoken(): + def get_csrftoken() -> Optional[str]: nonlocal should_set_cookie nonlocal page_needs_vary_header page_needs_vary_header = True @@ -94,7 +103,7 @@ def get_csrftoken(): scope = {**scope, **{SCOPE_KEY: get_csrftoken}} - async def wrapped_send(event): + async def wrapped_send(event: "Message") -> None: if event["type"] == "http.response.start": original_headers = event.get("headers") or [] new_headers = [] @@ -144,7 +153,7 @@ async def wrapped_send(event): await app(scope, receive, wrapped_send) else: # Check for CSRF token in various places - headers = dict(scope.get("headers" or [])) + headers = dict(scope.get("headers") or []) if secrets.compare_digest( headers.get(http_header.encode("latin-1"), b"").decode("latin-1"), csrftoken, @@ -226,7 +235,7 @@ async def wrapped_send(event): return _asgi_csrf_decorator -async def _parse_form_urlencoded(receive): +async def _parse_form_urlencoded(receive: "Receive") -> Tuple[Dict[str, str], "Receive"]: # Returns {key: value}, replay_receive # where replay_receive is an awaitable that can replay what was received # We ignore cases like foo=one&foo=two because we do not need to @@ -241,7 +250,7 @@ async def _parse_form_urlencoded(receive): body += message.get("body", b"") more_body = message.get("more_body", False) - async def replay_receive(): + async def replay_receive() -> "Message": if messages: return messages.pop(0) else: @@ -262,7 +271,7 @@ class FileBeforeToken(Exception): pass -async def _parse_multipart_form_data(boundary, receive): +async def _parse_multipart_form_data(boundary: bytes, receive: "Receive") -> Tuple[Optional[str], "Receive"]: # Returns (csrftoken, replay_receive) - or raises an exception csrftoken = None @@ -272,17 +281,17 @@ def on_field(field): raise TokenFound(csrftoken) class ErrorOnWrite: - def __init__(self, file_name, field_name, config): + def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Mapping[str, Any]) -> None: pass - def write(self, data): + def write(self, data: bytes) -> int: raise FileBeforeToken body = b"" more_body = True - messages = [] + messages: List["Message"] = [] - async def replay_receive(): + async def replay_receive() -> "Message": if messages: return messages.pop(0) else: @@ -308,7 +317,7 @@ async def replay_receive(): return None, replay_receive -async def default_send_csrf_failed(scope, send, message_id): +async def default_send_csrf_failed(scope: "Scope", send: "Send", message_id: int) -> None: assert scope["type"] == "http" await send( { @@ -323,19 +332,19 @@ async def default_send_csrf_failed(scope, send, message_id): def asgi_csrf( app, - cookie_name=DEFAULT_COOKIE_NAME, - http_header=DEFAULT_HTTP_HEADER, - signing_secret=None, - signing_namespace=DEFAULT_SIGNING_NAMESPACE, - always_protect=None, - always_set_cookie=False, - skip_if_scope=None, - cookie_path=DEFAULT_COOKIE_PATH, - cookie_domain=DEFAULT_COOKIE_DOMAIN, - cookie_secure=DEFAULT_COOKIE_SECURE, - cookie_samesite=DEFAULT_COOKIE_SAMESITE, - send_csrf_failed=None, -): + cookie_name: str = DEFAULT_COOKIE_NAME, + http_header: str = DEFAULT_HTTP_HEADER, + signing_secret: Optional[str] = None, + signing_namespace: str = DEFAULT_SIGNING_NAMESPACE, + always_protect: Optional[Container[str]] = None, + always_set_cookie: bool = False, + skip_if_scope: Optional[Callable[["Scope"], bool]] = None, + cookie_path: str = DEFAULT_COOKIE_PATH, + cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN, + cookie_secure: bool = DEFAULT_COOKIE_SECURE, + cookie_samesite: str = DEFAULT_COOKIE_SAMESITE, + send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None, +) -> "ASGIApp": return asgi_csrf_decorator( cookie_name, http_header, @@ -352,7 +361,7 @@ def asgi_csrf( )(app) -def cookies_from_scope(scope): +def cookies_from_scope(scope: "Scope") -> Dict[str, str]: cookie = dict(scope.get("headers") or {}).get(b"cookie") if not cookie: return {} @@ -364,5 +373,5 @@ def cookies_from_scope(scope): allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -def make_secret(length): +def make_secret(length: int) -> str: return "".join(secrets.choice(allowed_chars) for i in range(length)) diff --git a/test_asgi_csrf.py b/test_asgi_csrf.py index fe69060..d09780b 100644 --- a/test_asgi_csrf.py +++ b/test_asgi_csrf.py @@ -186,7 +186,7 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken): @pytest.mark.asyncio @pytest.mark.parametrize("custom_errors", (False, True)) -async def test_prevents_post_if_cookie_not_sent_in_post( +async def test_prevents_post_if_cookie_different_than_data( custom_errors, app_csrf, app_csrf_custom_errors, csrftoken ): async with httpx.AsyncClient(