Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Add module type hints #33

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 54 additions & 45 deletions asgi_csrf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from http.cookies import SimpleCookie
from enum import Enum
import fnmatch
from collections.abc import Awaitable, Callable, Container, Mapping, MutableMapping
from enum import IntEnum
from functools import wraps
from multipart import FormParser
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import parse_qsl
from itsdangerous.url_safe import URLSafeSerializer
from itsdangerous import BadSignature
import secrets

if TYPE_CHECKING:
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]

DEFAULT_COOKIE_NAME = "csrftoken"
DEFAULT_COOKIE_PATH = "/"
DEFAULT_COOKIE_DOMAIN = None
Expand All @@ -21,7 +29,8 @@
ENV_SECRET = "ASGI_CSRF_SECRET"


class Errors(Enum):

class Errors(IntEnum):
FORM_URLENCODED_MISMATCH = 1
MULTIPART_MISMATCH = 2
FILE_BEFORE_TOKEN = 3
Expand All @@ -37,30 +46,30 @@ class Errors(Enum):


def asgi_csrf_decorator(
cookie_name=DEFAULT_COOKIE_NAME,
http_header=DEFAULT_HTTP_HEADER,
form_input=DEFAULT_FORM_INPUT,
signing_secret=None,
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
always_protect=None,
always_set_cookie=False,
skip_if_scope=None,
cookie_path=DEFAULT_COOKIE_PATH,
cookie_domain=DEFAULT_COOKIE_DOMAIN,
cookie_secure=DEFAULT_COOKIE_SECURE,
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
send_csrf_failed=None,
):
cookie_name: str = DEFAULT_COOKIE_NAME,
http_header: str = DEFAULT_HTTP_HEADER,
form_input: str = DEFAULT_FORM_INPUT,
signing_secret: Optional[str] = None,
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
always_protect: Optional[Container[str]] = None,
always_set_cookie: bool = False,
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
cookie_path: str = DEFAULT_COOKIE_PATH,
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
) -> Callable[["ASGIApp"], "ASGIApp"]:
send_csrf_failed = send_csrf_failed or default_send_csrf_failed
if signing_secret is None:
signing_secret = os.environ.get(ENV_SECRET, None)
if signing_secret is None:
signing_secret = make_secret(128)
signer = URLSafeSerializer(signing_secret)

def _asgi_csrf_decorator(app):
def _asgi_csrf_decorator(app: "ASGIApp") -> "ASGIApp":
@wraps(app)
async def app_wrapped_with_csrf(scope, receive, send):
async def app_wrapped_with_csrf(scope: "Scope", receive: "Receive", send: "Send") -> None:
if scope["type"] != "http":
await app(scope, receive, send)
return
Expand All @@ -84,7 +93,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
if not has_csrftoken_cookie:
csrftoken = signer.dumps(make_secret(16), signing_namespace)

def get_csrftoken():
def get_csrftoken() -> Optional[str]:
nonlocal should_set_cookie
nonlocal page_needs_vary_header
page_needs_vary_header = True
Expand All @@ -94,7 +103,7 @@ def get_csrftoken():

scope = {**scope, **{SCOPE_KEY: get_csrftoken}}

async def wrapped_send(event):
async def wrapped_send(event: "Message") -> None:
if event["type"] == "http.response.start":
original_headers = event.get("headers") or []
new_headers = []
Expand Down Expand Up @@ -144,7 +153,7 @@ async def wrapped_send(event):
await app(scope, receive, wrapped_send)
else:
# Check for CSRF token in various places
headers = dict(scope.get("headers" or []))
headers = dict(scope.get("headers") or [])
if secrets.compare_digest(
headers.get(http_header.encode("latin-1"), b"").decode("latin-1"),
csrftoken,
Expand Down Expand Up @@ -226,7 +235,7 @@ async def wrapped_send(event):
return _asgi_csrf_decorator


async def _parse_form_urlencoded(receive):
async def _parse_form_urlencoded(receive: "Receive") -> Tuple[Dict[str, str], "Receive"]:
# Returns {key: value}, replay_receive
# where replay_receive is an awaitable that can replay what was received
# We ignore cases like foo=one&foo=two because we do not need to
Expand All @@ -241,7 +250,7 @@ async def _parse_form_urlencoded(receive):
body += message.get("body", b"")
more_body = message.get("more_body", False)

async def replay_receive():
async def replay_receive() -> "Message":
if messages:
return messages.pop(0)
else:
Expand All @@ -262,7 +271,7 @@ class FileBeforeToken(Exception):
pass


async def _parse_multipart_form_data(boundary, receive):
async def _parse_multipart_form_data(boundary: bytes, receive: "Receive") -> Tuple[Optional[str], "Receive"]:
# Returns (csrftoken, replay_receive) - or raises an exception
csrftoken = None

Expand All @@ -272,17 +281,17 @@ def on_field(field):
raise TokenFound(csrftoken)

class ErrorOnWrite:
def __init__(self, file_name, field_name, config):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Mapping[str, Any]) -> None:
pass

def write(self, data):
def write(self, data: bytes) -> int:
raise FileBeforeToken

body = b""
more_body = True
messages = []
messages: List["Message"] = []

async def replay_receive():
async def replay_receive() -> "Message":
if messages:
return messages.pop(0)
else:
Expand All @@ -308,7 +317,7 @@ async def replay_receive():
return None, replay_receive


async def default_send_csrf_failed(scope, send, message_id):
async def default_send_csrf_failed(scope: "Scope", send: "Send", message_id: int) -> None:
assert scope["type"] == "http"
await send(
{
Expand All @@ -323,19 +332,19 @@ async def default_send_csrf_failed(scope, send, message_id):

def asgi_csrf(
app,
cookie_name=DEFAULT_COOKIE_NAME,
http_header=DEFAULT_HTTP_HEADER,
signing_secret=None,
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
always_protect=None,
always_set_cookie=False,
skip_if_scope=None,
cookie_path=DEFAULT_COOKIE_PATH,
cookie_domain=DEFAULT_COOKIE_DOMAIN,
cookie_secure=DEFAULT_COOKIE_SECURE,
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
send_csrf_failed=None,
):
cookie_name: str = DEFAULT_COOKIE_NAME,
http_header: str = DEFAULT_HTTP_HEADER,
signing_secret: Optional[str] = None,
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
always_protect: Optional[Container[str]] = None,
always_set_cookie: bool = False,
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
cookie_path: str = DEFAULT_COOKIE_PATH,
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
) -> "ASGIApp":
return asgi_csrf_decorator(
cookie_name,
http_header,
Expand All @@ -352,7 +361,7 @@ def asgi_csrf(
)(app)


def cookies_from_scope(scope):
def cookies_from_scope(scope: "Scope") -> Dict[str, str]:
cookie = dict(scope.get("headers") or {}).get(b"cookie")
if not cookie:
return {}
Expand All @@ -364,5 +373,5 @@ def cookies_from_scope(scope):
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"


def make_secret(length):
def make_secret(length: int) -> str:
return "".join(secrets.choice(allowed_chars) for i in range(length))
2 changes: 1 addition & 1 deletion test_asgi_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):

@pytest.mark.asyncio
@pytest.mark.parametrize("custom_errors", (False, True))
async def test_prevents_post_if_cookie_not_sent_in_post(
async def test_prevents_post_if_cookie_different_than_data(
custom_errors, app_csrf, app_csrf_custom_errors, csrftoken
):
async with httpx.AsyncClient(
Expand Down