From 3dc066c5b32123f7063e0f6b96e28193a5006577 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 22 Sep 2024 10:05:32 +0200 Subject: [PATCH 1/8] typing: type jsonschema --- falcon/media/validators/jsonschema.py | 32 ++++++++++++++++++++++----- pyproject.toml | 1 - 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/falcon/media/validators/jsonschema.py b/falcon/media/validators/jsonschema.py index 0e6f14b67..561fcfdf1 100644 --- a/falcon/media/validators/jsonschema.py +++ b/falcon/media/validators/jsonschema.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from functools import wraps from inspect import iscoroutinefunction +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING import falcon @@ -8,8 +11,17 @@ except ImportError: # pragma: nocover pass +if TYPE_CHECKING: + import falcon as wsgi + from falcon import asgi + +Schema = Optional[Dict[str, Any]] +ResponderMethod = Callable[..., Any] + -def validate(req_schema=None, resp_schema=None, is_async=False): +def validate( + req_schema: Schema = None, resp_schema: Schema = None, is_async: bool = False +) -> Callable[[ResponderMethod], ResponderMethod]: """Validate ``req.media`` using JSON Schema. This decorator provides standard JSON Schema validation via the @@ -99,7 +111,7 @@ async def on_post(self, req, resp): """ - def decorator(func): + def decorator(func: ResponderMethod) -> ResponderMethod: if iscoroutinefunction(func) or is_async: return _validate_async(func, req_schema, resp_schema) @@ -108,9 +120,13 @@ def decorator(func): return decorator -def _validate(func, req_schema=None, resp_schema=None): +def _validate( + func: ResponderMethod, req_schema: Schema = None, resp_schema: Schema = None +) -> ResponderMethod: @wraps(func) - def wrapper(self, req, resp, *args, **kwargs): + def wrapper( + self: Any, req: wsgi.Request, resp: wsgi.Response, *args: Any, **kwargs: Any + ) -> Any: if req_schema is not None: try: jsonschema.validate( @@ -141,9 +157,13 @@ def wrapper(self, req, resp, *args, **kwargs): return wrapper -def _validate_async(func, req_schema=None, resp_schema=None): +def _validate_async( + func: ResponderMethod, req_schema: Schema = None, resp_schema: Schema = None +) -> ResponderMethod: @wraps(func) - async def wrapper(self, req, resp, *args, **kwargs): + async def wrapper( + self: Any, req: asgi.Request, resp: asgi.Response, *args: Any, **kwargs: Any + ) -> Any: if req_schema is not None: m = await req.get_media() diff --git a/pyproject.toml b/pyproject.toml index ee74ee31d..d2a4b6c11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,6 @@ exclude = ["examples", "tests"] [[tool.mypy.overrides]] module = [ - "falcon.media.validators.*", "falcon.testing.*", "falcon.vendor.*", ] From 177d0c399834cbeec03d72236df6c396727909b9 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 22 Sep 2024 10:19:31 +0200 Subject: [PATCH 2/8] typing: add missing future annotation; add lint rule to check for it --- examples/ws_tutorial/ws_tutorial/app.py | 2 ++ falcon/app.py | 2 ++ falcon/asgi/app.py | 2 ++ falcon/testing/client.py | 2 ++ falcon/testing/helpers.py | 2 ++ falcon/util/deprecation.py | 2 ++ falcon/util/sync.py | 2 ++ falcon/util/time.py | 2 ++ falcon/util/uri.py | 2 ++ pyproject.toml | 5 ++--- tests/test_status_codes.py | 3 +-- 11 files changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/ws_tutorial/ws_tutorial/app.py b/examples/ws_tutorial/ws_tutorial/app.py index 054514a61..9dfab7e81 100644 --- a/examples/ws_tutorial/ws_tutorial/app.py +++ b/examples/ws_tutorial/ws_tutorial/app.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime import logging import pathlib diff --git a/falcon/app.py b/falcon/app.py index b66247051..a82e75445 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -14,6 +14,8 @@ """Falcon App class.""" +from __future__ import annotations + from functools import wraps from inspect import iscoroutinefunction import pathlib diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index 4478728c2..69f154c15 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -14,6 +14,8 @@ """ASGI application class.""" +from __future__ import annotations + import asyncio from inspect import isasyncgenfunction from inspect import iscoroutinefunction diff --git a/falcon/testing/client.py b/falcon/testing/client.py index 1521128e7..75a4f2882 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -18,6 +18,8 @@ WSGI callable, without having to stand up a WSGI server. """ +from __future__ import annotations + import asyncio import datetime as dt import inspect diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index 05513b885..5b1a896e7 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -22,6 +22,8 @@ wsgi_environ = testing.create_environ() """ +from __future__ import annotations + import asyncio from collections import defaultdict from collections import deque diff --git a/falcon/util/deprecation.py b/falcon/util/deprecation.py index 56ccf213c..fdb99e721 100644 --- a/falcon/util/deprecation.py +++ b/falcon/util/deprecation.py @@ -17,6 +17,8 @@ This module provides decorators to mark functions and classes as deprecated. """ +from __future__ import annotations + import functools from typing import Any, Callable, Optional import warnings diff --git a/falcon/util/sync.py b/falcon/util/sync.py index 0bfc24021..4aebe23ae 100644 --- a/falcon/util/sync.py +++ b/falcon/util/sync.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from concurrent.futures import ThreadPoolExecutor from functools import partial diff --git a/falcon/util/time.py b/falcon/util/time.py index 6e4cce993..ad1e517b6 100644 --- a/falcon/util/time.py +++ b/falcon/util/time.py @@ -9,6 +9,8 @@ tz = falcon.TimezoneGMT() """ +from __future__ import annotations + import datetime from typing import Optional diff --git a/falcon/util/uri.py b/falcon/util/uri.py index 80e6b62bb..549dc721b 100644 --- a/falcon/util/uri.py +++ b/falcon/util/uri.py @@ -23,6 +23,8 @@ name, port = uri.parse_host('example.org:8080') """ +from __future__ import annotations + from typing import Callable, Dict, List, Optional, overload, Tuple, Union from falcon.constants import PYPY diff --git a/pyproject.toml b/pyproject.toml index d2a4b6c11..a6b6cc96e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,6 @@ exclude = ["examples", "tests"] [[tool.mypy.overrides]] module = [ - "falcon.testing.*", "falcon.vendor.*", ] disallow_untyped_defs = false @@ -189,7 +188,6 @@ exclude = ["examples", "tests"] "build", "dist", "docs", - "examples", "falcon/bench/nuts", ] @@ -199,7 +197,8 @@ exclude = ["examples", "tests"] "E", "F", "W", - "I" + "I", + "FA" ] [tool.ruff.lint.mccabe] diff --git a/tests/test_status_codes.py b/tests/test_status_codes.py index b89736883..9017d0165 100644 --- a/tests/test_status_codes.py +++ b/tests/test_status_codes.py @@ -1,6 +1,5 @@ import http import sys -from typing import Tuple import pytest @@ -22,7 +21,7 @@ def test_statuses_are_in_compliance_with_http_from_python313(self, status): else: assert http_status.phrase == message - def _status_code_and_message(self, status: str) -> Tuple[int, str]: + def _status_code_and_message(self, status: str): status = getattr(status_codes, status) value, message = status.split(' ', 1) return int(value), message From 227402eb390800ba973fb9702e1b3ca7ca9c20b9 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 22 Sep 2024 10:20:00 +0200 Subject: [PATCH 3/8] refactor: rename HeaderList to HeaderArg --- falcon/asgi/ws.py | 6 +-- falcon/errors.py | 82 +++++++++++++++++++------------------- falcon/http_error.py | 4 +- falcon/http_status.py | 6 +-- falcon/testing/resource.py | 4 +- falcon/typing.py | 3 +- 6 files changed, 53 insertions(+), 52 deletions(-) diff --git a/falcon/asgi/ws.py b/falcon/asgi/ws.py index c02031993..f2c693d5d 100644 --- a/falcon/asgi/ws.py +++ b/falcon/asgi/ws.py @@ -17,7 +17,7 @@ from falcon.constants import WebSocketPayloadType from falcon.typing import AsgiReceive from falcon.typing import AsgiSend -from falcon.typing import HeaderList +from falcon.typing import HeaderArg from falcon.util import misc __all__ = ('WebSocket',) @@ -144,7 +144,7 @@ def supports_accept_headers(self) -> bool: async def accept( self, subprotocol: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, ) -> None: """Accept the incoming WebSocket connection. @@ -166,7 +166,7 @@ async def accept( client may choose to abandon the connection in this case, if it does not receive an explicit protocol selection. - headers (HeaderList): An iterable of ``(name: str, value: str)`` + headers (HeaderArg): An iterable of ``(name: str, value: str)`` two-item iterables, representing a collection of HTTP headers to include in the handshake response. Both *name* and *value* must be of type ``str`` and contain only US-ASCII characters. diff --git a/falcon/errors.py b/falcon/errors.py index afc1faa8a..a95d52c8d 100644 --- a/falcon/errors.py +++ b/falcon/errors.py @@ -45,7 +45,7 @@ def on_get(self, req, resp): from falcon.util.misc import dt_to_http if TYPE_CHECKING: - from falcon.typing import HeaderList + from falcon.typing import HeaderArg from falcon.typing import Headers @@ -229,7 +229,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ) -> None: super().__init__( @@ -312,7 +312,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, challenges: Optional[Iterable[str]] = None, **kwargs: HTTPErrorKeywordArguments, ): @@ -390,7 +390,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -461,7 +461,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -588,7 +588,7 @@ def __init__( allowed_methods: Iterable[str], title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): headers = _load_headers(headers) @@ -659,7 +659,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -732,7 +732,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -811,7 +811,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -875,7 +875,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -940,7 +940,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1016,7 +1016,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, retry_after: RetryAfter = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ) -> None: super().__init__( @@ -1086,7 +1086,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1151,7 +1151,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1230,7 +1230,7 @@ def __init__( resource_length: int, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): headers = _load_headers(headers) @@ -1300,7 +1300,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1362,7 +1362,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1423,7 +1423,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1492,7 +1492,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1566,7 +1566,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, retry_after: RetryAfter = None, **kwargs: HTTPErrorKeywordArguments, ): @@ -1635,7 +1635,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1710,7 +1710,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1771,7 +1771,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1839,7 +1839,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1900,7 +1900,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -1977,7 +1977,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, retry_after: RetryAfter = None, **kwargs: HTTPErrorKeywordArguments, ): @@ -2040,7 +2040,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -2107,7 +2107,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -2172,7 +2172,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -2234,7 +2234,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -2308,7 +2308,7 @@ def __init__( self, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): super().__init__( @@ -2367,7 +2367,7 @@ def __init__( self, msg: str, header_name: str, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): description = 'The value provided for the "{0}" header is invalid. {1}' @@ -2426,7 +2426,7 @@ class HTTPMissingHeader(HTTPBadRequest): def __init__( self, header_name: str, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ): description = 'The "{0}" header is required.' @@ -2489,7 +2489,7 @@ def __init__( self, msg: str, param_name: str, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ) -> None: description = 'The "{0}" parameter is invalid. {1}' @@ -2550,7 +2550,7 @@ class HTTPMissingParam(HTTPBadRequest): def __init__( self, param_name: str, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ) -> None: description = 'The "{0}" parameter is required.' @@ -2649,7 +2649,7 @@ class MediaMalformedError(HTTPBadRequest): """ def __init__( - self, media_type: str, **kwargs: Union[HeaderList, HTTPErrorKeywordArguments] + self, media_type: str, **kwargs: Union[HeaderArg, HTTPErrorKeywordArguments] ): super().__init__( title='Invalid {0}'.format(media_type), description=None, **kwargs @@ -2718,7 +2718,7 @@ def __init__( *, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, **kwargs: HTTPErrorKeywordArguments, ) -> None: super().__init__( @@ -2753,7 +2753,7 @@ class MultipartParseError(MediaMalformedError): def __init__( self, description: Optional[str] = None, - **kwargs: Union[HeaderList, HTTPErrorKeywordArguments], + **kwargs: Union[HeaderArg, HTTPErrorKeywordArguments], ) -> None: HTTPBadRequest.__init__( self, @@ -2768,7 +2768,7 @@ def __init__( # ----------------------------------------------------------------------------- -def _load_headers(headers: Optional[HeaderList]) -> Headers: +def _load_headers(headers: Optional[HeaderArg]) -> Headers: """Transform the headers to dict.""" if headers is None: return {} @@ -2778,9 +2778,9 @@ def _load_headers(headers: Optional[HeaderList]) -> Headers: def _parse_retry_after( - headers: Optional[HeaderList], + headers: Optional[HeaderArg], retry_after: RetryAfter, -) -> Optional[HeaderList]: +) -> Optional[HeaderArg]: """Set the Retry-After to the headers when required.""" if retry_after is None: return headers diff --git a/falcon/http_error.py b/falcon/http_error.py index 3f4af635c..d6492397f 100644 --- a/falcon/http_error.py +++ b/falcon/http_error.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from falcon.media import BaseHandler - from falcon.typing import HeaderList + from falcon.typing import HeaderArg from falcon.typing import Link from falcon.typing import ResponseStatus @@ -123,7 +123,7 @@ def __init__( status: ResponseStatus, title: Optional[str] = None, description: Optional[str] = None, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, href: Optional[str] = None, href_text: Optional[str] = None, code: Optional[int] = None, diff --git a/falcon/http_status.py b/falcon/http_status.py index 1a591fcf4..fcc002d97 100644 --- a/falcon/http_status.py +++ b/falcon/http_status.py @@ -21,7 +21,7 @@ from falcon.util.deprecation import AttributeRemovedError if TYPE_CHECKING: - from falcon.typing import HeaderList + from falcon.typing import HeaderArg from falcon.typing import ResponseStatus @@ -48,7 +48,7 @@ class HTTPStatus(Exception): """The HTTP status line or integer code for the status that this exception represents. """ - headers: Optional[HeaderList] + headers: Optional[HeaderArg] """Extra headers to add to the response.""" text: Optional[str] """String representing response content. @@ -58,7 +58,7 @@ class HTTPStatus(Exception): def __init__( self, status: ResponseStatus, - headers: Optional[HeaderList] = None, + headers: Optional[HeaderArg] = None, text: Optional[str] = None, ) -> None: self.status = status diff --git a/falcon/testing/resource.py b/falcon/testing/resource.py index cb1715bcb..32b8833f8 100644 --- a/falcon/testing/resource.py +++ b/falcon/testing/resource.py @@ -33,7 +33,7 @@ if typing.TYPE_CHECKING: # pragma: no cover from falcon import app as wsgi from falcon.asgi import app as asgi - from falcon.typing import HeaderList + from falcon.typing import HeaderArg from falcon.typing import Resource @@ -181,7 +181,7 @@ def __init__( status: typing.Optional[str] = None, body: typing.Optional[str] = None, json: typing.Optional[dict[str, str]] = None, - headers: typing.Optional[HeaderList] = None, + headers: typing.Optional[HeaderArg] = None, ): self._default_status = status self._default_headers = headers diff --git a/falcon/typing.py b/falcon/typing.py index a548b32ac..7bcdfa185 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -117,7 +117,8 @@ async def __call__( # class SinkCallable(Protocol): # def __call__(sef, req: Request, resp: Response, ): ... Headers = Dict[str, str] -HeaderList = Union[Headers, List[Tuple[str, str]]] +HeaderList = List[Tuple[str, str]] +HeaderArg = Union[Headers, HeaderList] ResponseStatus = Union[http.HTTPStatus, str, int] StoreArgument = Optional[Dict[str, Any]] Resource = object From e22a0b761df7aa685dad74288da8928725dc9ad3 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 22 Sep 2024 11:23:21 +0200 Subject: [PATCH 4/8] typing: type testing helpers --- falcon/testing/helpers.py | 277 +++++++++++++++++++++---------------- falcon/testing/resource.py | 4 +- falcon/testing/srmock.py | 23 ++- 3 files changed, 178 insertions(+), 126 deletions(-) diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index 5b1a896e7..c44208ec0 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -30,6 +30,7 @@ import contextlib from enum import auto from enum import Enum +from http.cookiejar import Cookie import io import itertools import json @@ -38,16 +39,35 @@ import socket import sys import time -from typing import Any, Dict, Iterable, Optional, Union +from typing import ( + Any, + Callable, + Deque, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + TextIO, + Tuple, + Type, + Union, +) import falcon from falcon import errors as falcon_errors import falcon.asgi +from falcon.asgi_spec import AsgiEvent from falcon.asgi_spec import EventType from falcon.asgi_spec import ScopeType from falcon.asgi_spec import WSCloseCode from falcon.constants import SINGLETON_HEADERS import falcon.request +from falcon.typing import HeaderArg +from falcon.typing import HeaderList +from falcon.typing import ResponseStatus +from falcon.util import code_to_http_status from falcon.util import uri from falcon.util.mediatypes import parse_header @@ -75,11 +95,11 @@ class ASGILifespanEventEmitter: emitting the final shutdown event (``'lifespan.shutdown``). """ - def __init__(self, shutting_down): + def __init__(self, shutting_down: asyncio.Condition) -> None: self._state = 0 self._shutting_down = shutting_down - async def emit(self): + async def emit(self) -> AsgiEvent: if self._state == 0: self._state += 1 return {'type': EventType.LIFESPAN_STARTUP} @@ -134,14 +154,14 @@ class ASGIRequestEventEmitter: # TODO(kgriffs): If this pattern later becomes useful elsewhere, # factor out into a standalone helper class. - _branch_decider = defaultdict(bool) # type: defaultdict + _branch_decider: dict[str, bool] = defaultdict(bool) def __init__( self, body: Optional[Union[str, bytes]] = None, chunk_size: Optional[int] = None, disconnect_at: Optional[Union[int, float]] = None, - ): + ) -> None: if body is None: body = b'' elif not isinstance(body, bytes): @@ -155,7 +175,7 @@ def __init__( if chunk_size is None: chunk_size = 4096 - self._body = body # type: Optional[memoryview] + self._body: Optional[memoryview] = body self._chunk_size = chunk_size self._emit_empty_chunks = True self._disconnect_at = disconnect_at @@ -166,10 +186,10 @@ def __init__( self._emitted_empty_chunk_b = False @property - def disconnected(self): + def disconnected(self) -> bool: return self._disconnected or (self._disconnect_at <= time.time()) - def disconnect(self, exhaust_body: Optional[bool] = None): + def disconnect(self, exhaust_body: Optional[bool] = None) -> None: """Set the client connection state to disconnected. Call this method to simulate an immediate client disconnect and @@ -186,7 +206,7 @@ def disconnect(self, exhaust_body: Optional[bool] = None): self._disconnected = True - async def emit(self) -> Dict[str, Any]: + async def emit(self) -> AsgiEvent: # NOTE(kgriffs): Special case: if we are immediately disconnected, # the first event should be 'http.disconnnect' if self._disconnect_at == 0: @@ -210,7 +230,7 @@ async def emit(self) -> Dict[str, Any]: return {'type': EventType.HTTP_DISCONNECT} - event = {'type': EventType.HTTP_REQUEST} # type: Dict[str, Any] + event: Dict[str, Any] = {'type': EventType.HTTP_REQUEST} if self._emit_empty_chunks: # NOTE(kgriffs): Return a couple variations on empty chunks @@ -266,7 +286,7 @@ async def emit(self) -> Dict[str, Any]: __call__ = emit - def _toggle_branch(self, name: str): + def _toggle_branch(self, name: str) -> bool: self._branch_decider[name] = not self._branch_decider[name] return self._branch_decider[name] @@ -305,14 +325,14 @@ class ASGIResponseEventCollector: _HEADER_NAME_RE = re.compile(rb'^[a-zA-Z][a-zA-Z0-9\-_]*$') _BAD_HEADER_VALUE_RE = re.compile(rb'[\000-\037]') - def __init__(self): - self.events = [] - self.headers = [] - self.status = None - self.body_chunks = [] - self.more_body = None + def __init__(self) -> None: + self.events: List[AsgiEvent] = [] + self.headers: HeaderList = [] + self.status: Optional[ResponseStatus] = None + self.body_chunks: list[bytes] = [] + self.more_body: Optional[bool] = None - async def collect(self, event: Dict[str, Any]): + async def collect(self, event: AsgiEvent) -> None: if self.more_body is False: # NOTE(kgriffs): According to the ASGI spec, once we get a # message setting more_body to False, any further messages @@ -410,17 +430,17 @@ class ASGIWebSocketSimulator: _DEFAULT_WAIT_READY_TIMEOUT = 5 - def __init__(self): + def __init__(self) -> None: self.__msgpack = None self._state = _WebSocketState.CONNECT self._disconnect_emitted = False - self._close_code = None - self._close_reason = None - self._accepted_subprotocol = None - self._accepted_headers = None - self._collected_server_events = deque() - self._collected_client_events = deque() + self._close_code: Optional[int] = None + self._close_reason: Optional[str] = None + self._accepted_subprotocol: Optional[str] = None + self._accepted_headers: Optional[List[Tuple[bytes, bytes]]] = None + self._collected_server_events: Deque[AsgiEvent] = deque() + self._collected_client_events: Deque[AsgiEvent] = deque() self._event_handshake_complete = asyncio.Event() @@ -433,22 +453,22 @@ def closed(self) -> bool: return self._state in {_WebSocketState.DENIED, _WebSocketState.CLOSED} @property - def close_code(self) -> int: + def close_code(self) -> Optional[int]: return self._close_code @property - def close_reason(self) -> str: + def close_reason(self) -> Optional[str]: return self._close_reason @property - def subprotocol(self) -> str: + def subprotocol(self) -> Optional[str]: return self._accepted_subprotocol @property - def headers(self) -> Iterable[Iterable[bytes]]: + def headers(self) -> Optional[List[Tuple[bytes, bytes]]]: return self._accepted_headers - async def wait_ready(self, timeout: Optional[int] = None): + async def wait_ready(self, timeout: Optional[int] = None) -> None: """Wait until the connection has been accepted or denied. This coroutine can be awaited in order to pause execution until the @@ -478,7 +498,9 @@ async def wait_ready(self, timeout: Optional[int] = None): # NOTE(kgriffs): This is a coroutine just in case we need it to be # in a future code revision. It also makes it more consistent # with the other methods. - async def close(self, code: Optional[int] = None, reason: Optional[str] = None): + async def close( + self, code: Optional[int] = None, reason: Optional[str] = None + ) -> None: """Close the simulated connection. Keyword Args: @@ -511,7 +533,7 @@ async def close(self, code: Optional[int] = None, reason: Optional[str] = None): self._close_code = code self._close_reason = reason - async def send_text(self, payload: str): + async def send_text(self, payload: str) -> None: """Send a message to the app with a Unicode string payload. Arguments: @@ -528,7 +550,7 @@ async def send_text(self, payload: str): # but the server will be expecting websocket.receive await self._send(text=payload) - async def send_data(self, payload: Union[bytes, bytearray, memoryview]): + async def send_data(self, payload: Union[bytes, bytearray, memoryview]) -> None: """Send a message to the app with a binary data payload. Arguments: @@ -545,7 +567,7 @@ async def send_data(self, payload: Union[bytes, bytearray, memoryview]): # but the server will be expecting websocket.receive await self._send(data=bytes(payload)) - async def send_json(self, media: object): + async def send_json(self, media: object) -> None: """Send a message to the app with a JSON-encoded payload. Arguments: @@ -555,7 +577,7 @@ async def send_json(self, media: object): text = json.dumps(media) await self.send_text(text) - async def send_msgpack(self, media: object): + async def send_msgpack(self, media: object) -> None: """Send a message to the app with a MessagePack-encoded payload. Arguments: @@ -613,7 +635,7 @@ async def receive_data(self) -> bytes: return data - async def receive_json(self) -> object: + async def receive_json(self) -> Any: """Receive a message from the app with a JSON-encoded TEXT payload. Awaiting this coroutine will block until a message is available or @@ -623,7 +645,7 @@ async def receive_json(self) -> object: text = await self.receive_text() return json.loads(text) - async def receive_msgpack(self) -> object: + async def receive_msgpack(self) -> Any: """Receive a message from the app with a MessagePack-encoded BINARY payload. Awaiting this coroutine will block until a message is available or @@ -634,7 +656,7 @@ async def receive_msgpack(self) -> object: return self._msgpack.unpackb(data, use_list=True, raw=False) @property - def _msgpack(self): + def _msgpack(self) -> Any: # NOTE(kgriffs): A property is used in lieu of referencing # the msgpack module directly, in order to bubble up the # import error in an obvious way, when the package has @@ -647,7 +669,7 @@ def _msgpack(self): return self.__msgpack - def _require_accepted(self): + def _require_accepted(self) -> None: if self._state == _WebSocketState.ACCEPTED: return @@ -674,13 +696,14 @@ def _require_accepted(self): # NOTE(kgriffs): This is a coroutine just in case we need it to be # in a future code revision. It also makes it more consistent # with the other methods. - async def _send(self, data: Optional[bytes] = None, text: Optional[str] = None): + async def _send( + self, data: Optional[bytes] = None, text: Optional[str] = None + ) -> None: self._require_accepted() # NOTE(kgriffs): From the client's perspective, it was a send, # but the server will be expecting websocket.receive - event = {'type': EventType.WS_RECEIVE} # type: Dict[str, Union[bytes, str]] - + event: Dict[str, Any] = {'type': EventType.WS_RECEIVE} if data is not None: event['bytes'] = data @@ -694,7 +717,7 @@ async def _send(self, data: Optional[bytes] = None, text: Optional[str] = None): # like it's 1992.) await asyncio.sleep(0) - async def _receive(self) -> Dict[str, Any]: + async def _receive(self) -> AsgiEvent: while not self._collected_server_events: self._require_accepted() await asyncio.sleep(0) @@ -702,7 +725,7 @@ async def _receive(self) -> Dict[str, Any]: self._require_accepted() return self._collected_server_events.popleft() - async def _emit(self) -> Dict[str, Any]: + async def _emit(self) -> AsgiEvent: if self._state == _WebSocketState.CONNECT: self._state = _WebSocketState.HANDSHAKE return {'type': EventType.WS_CONNECT} @@ -719,7 +742,7 @@ async def _emit(self) -> Dict[str, Any]: return self._collected_client_events.popleft() - async def _collect(self, event: Dict[str, Any]): + async def _collect(self, event: AsgiEvent) -> None: assert event if self._state == _WebSocketState.CONNECT: @@ -763,7 +786,7 @@ async def _collect(self, event: Dict[str, Any]): # connection is closed with a 403 and there is no websocket # close code). self._close_code = WSCloseCode.FORBIDDEN - self._close_reason = falcon.util.code_to_http_status( + self._close_reason = code_to_http_status( WSCloseCode.FORBIDDEN - 3000 ) @@ -796,7 +819,7 @@ async def _collect(self, event: Dict[str, Any]): # collected data/text event a chance to progress. await asyncio.sleep(0) - def _create_checked_disconnect(self) -> Dict[str, Any]: + def _create_checked_disconnect(self) -> AsgiEvent: if self._disconnect_emitted: raise falcon_errors.OperationNotAllowed( 'The websocket.disconnect event has already been emitted, ' @@ -816,7 +839,7 @@ def _create_checked_disconnect(self) -> Dict[str, Any]: # get_encoding_from_headers() is Copyright 2016 Kenneth Reitz, and is # used here under the terms of the Apache License, Version 2.0. -def get_encoding_from_headers(headers): +def get_encoding_from_headers(headers: Dict[str, str]) -> Optional[str]: """Return encoding from given HTTP Header Dict. Args: @@ -863,7 +886,7 @@ def get_unused_port() -> int: return s.getsockname()[1] -def rand_string(min, max) -> str: +def rand_string(min: int, max: int) -> str: """Return a randomly-generated string, of a random length. Args: @@ -877,20 +900,23 @@ def rand_string(min, max) -> str: return ''.join([chr(int_gen(ord(' '), ord('~'))) for __ in range(string_length)]) +_CookieType = Mapping[str, Union[str, Cookie]] + + def create_scope( - path='/', - query_string='', - method='GET', - headers=None, - host=DEFAULT_HOST, - scheme=None, - port=None, - http_version='1.1', - remote_addr=None, - root_path=None, - content_length=None, - include_server=True, - cookies=None, + path: str = '/', + query_string: str = '', + method: str = 'GET', + headers: Optional[HeaderArg] = None, + host: str = DEFAULT_HOST, + scheme: Optional[str] = None, + port: Optional[int] = None, + http_version: str = '1.1', + remote_addr: Optional[str] = None, + root_path: Optional[str] = None, + content_length: Optional[int] = None, + include_server: bool = True, + cookies: Optional[_CookieType] = None, ) -> Dict[str, Any]: """Create a mock ASGI scope ``dict`` for simulating HTTP requests. @@ -948,12 +974,12 @@ def create_scope( path = uri.decode(path, unquote_plus=False) # NOTE(kgriffs): Handles both None and '' - query_string = query_string.encode() if query_string else b'' + query_string_bytes = query_string.encode() if query_string else b'' - if query_string and query_string.startswith(b'?'): + if query_string_bytes and query_string_bytes.startswith(b'?'): raise ValueError("query_string should not start with '?'") - scope = { + scope: Dict[str, Any] = { 'type': ScopeType.HTTP, 'asgi': { 'version': '3.0', @@ -962,7 +988,7 @@ def create_scope( 'http_version': http_version, 'method': method.upper(), 'path': path, - 'query_string': query_string, + 'query_string': query_string_bytes, } # NOTE(kgriffs): Explicitly test against None so that the caller @@ -1016,18 +1042,18 @@ def create_scope( def create_scope_ws( - path='/', - query_string='', - headers=None, - host=DEFAULT_HOST, - scheme=None, - port=None, - http_version='1.1', - remote_addr=None, - root_path=None, - include_server=True, - subprotocols=None, - spec_version='2.1', + path: str = '/', + query_string: str = '', + headers: Optional[HeaderArg] = None, + host: str = DEFAULT_HOST, + scheme: Optional[str] = None, + port: Optional[int] = None, + http_version: str = '1.1', + remote_addr: Optional[str] = None, + root_path: Optional[str] = None, + include_server: bool = True, + subprotocols: Optional[str] = None, + spec_version: str = '2.1', ) -> Dict[str, Any]: """Create a mock ASGI scope ``dict`` for simulating WebSocket requests. @@ -1099,21 +1125,21 @@ def create_scope_ws( def create_environ( - path='/', - query_string='', - http_version='1.1', - scheme='http', - host=DEFAULT_HOST, - port=None, - headers=None, - app=None, - body='', - method='GET', - wsgierrors=None, - file_wrapper=None, - remote_addr=None, - root_path=None, - cookies=None, + path: str = '/', + query_string: str = '', + http_version: str = '1.1', + scheme: str = 'http', + host: str = DEFAULT_HOST, + port: Optional[int] = None, + headers: Optional[HeaderArg] = None, + app: Optional[str] = None, + body: str = '', + method: str = 'GET', + wsgierrors: Optional[io.StringIO] = None, + file_wrapper: Optional[Callable[..., Any]] = None, + remote_addr: Optional[str] = None, + root_path: Optional[str] = None, + cookies: Optional[_CookieType] = None, ) -> Dict[str, Any]: """Create a mock PEP-3333 environ ``dict`` for simulating WSGI requests. @@ -1176,7 +1202,7 @@ def create_environ( if query_string and query_string.startswith('?'): raise ValueError("query_string should not start with '?'") - body = io.BytesIO(body.encode() if isinstance(body, str) else body) + body_bytes = io.BytesIO(body.encode() if isinstance(body, str) else body) # NOTE(kgriffs): wsgiref, gunicorn, and uWSGI all unescape # the paths before setting PATH_INFO but preserve raw original @@ -1201,11 +1227,11 @@ def create_environ( scheme = scheme.lower() if port is None: - port = '80' if scheme == 'http' else '443' + port_str = '80' if scheme == 'http' else '443' else: # NOTE(kgriffs): Running it through int() first ensures that if # a string was passed, it is a valid integer. - port = str(int(port)) + port_str = str(int(port)) root_path = root_path or app or '' @@ -1215,7 +1241,7 @@ def create_environ( if root_path and not root_path.startswith('/'): root_path = '/' + root_path - env = { + env: dict[str, Any] = { 'SERVER_PROTOCOL': 'HTTP/' + http_version, 'SERVER_SOFTWARE': 'gunicorn/0.17.0', 'SCRIPT_NAME': (root_path or ''), @@ -1225,10 +1251,10 @@ def create_environ( 'REMOTE_PORT': '65133', 'RAW_URI': raw_path, 'SERVER_NAME': host, - 'SERVER_PORT': port, + 'SERVER_PORT': port_str, 'wsgi.version': (1, 0), 'wsgi.url_scheme': scheme, - 'wsgi.input': body, + 'wsgi.input': body_bytes, 'wsgi.errors': wsgierrors or sys.stderr, 'wsgi.multithread': False, 'wsgi.multiprocess': True, @@ -1248,16 +1274,16 @@ def create_environ( host_header = host if scheme == 'https': - if port != '443': - host_header += ':' + port + if port_str != '443': + host_header += ':' + port_str else: - if port != '80': - host_header += ':' + port + if port_str != '80': + host_header += ':' + port_str env['HTTP_HOST'] = host_header - content_length = body.seek(0, 2) - body.seek(0) + content_length = body_bytes.seek(0, 2) + body_bytes.seek(0) if content_length != 0: env['CONTENT_LENGTH'] = str(content_length) @@ -1272,7 +1298,9 @@ def create_environ( return env -def create_req(options=None, **kwargs) -> falcon.Request: +def create_req( + options: Optional[falcon.request.RequestOptions] = None, **kwargs: Any +) -> falcon.Request: """Create and return a new Request instance. This function can be used to conveniently create a WSGI environ @@ -1290,7 +1318,10 @@ def create_req(options=None, **kwargs) -> falcon.Request: def create_asgi_req( - body=None, req_type=None, options=None, **kwargs + body: Optional[bytes] = None, + req_type: Optional[Type[falcon.asgi.Request]] = None, + options: Optional[falcon.request.RequestOptions] = None, + **kwargs: Any, ) -> falcon.asgi.Request: """Create and return a new ASGI Request instance. @@ -1326,7 +1357,9 @@ def create_asgi_req( @contextlib.contextmanager -def redirected(stdout=sys.stdout, stderr=sys.stderr): +def redirected( + stdout: TextIO = sys.stdout, stderr: TextIO = sys.stderr +) -> Iterator[None]: """Redirect stdout or stderr temporarily. e.g.: @@ -1343,7 +1376,7 @@ def redirected(stdout=sys.stdout, stderr=sys.stderr): sys.stderr, sys.stdout = old_stderr, old_stdout -def closed_wsgi_iterable(iterable): +def closed_wsgi_iterable(iterable: Iterable[bytes]) -> Iterable[bytes]: """Wrap an iterable to ensure its ``close()`` method is called. Wraps the given `iterable` in an iterator utilizing a ``for`` loop as @@ -1366,7 +1399,7 @@ def closed_wsgi_iterable(iterable): iterator: An iterator yielding the same bytestrings as `iterable` """ - def wrapper(): + def wrapper() -> Iterator[bytes]: try: for item in iterable: yield item @@ -1375,6 +1408,7 @@ def wrapper(): iterable.close() wrapped = wrapper() + head: Tuple[bytes, ...] try: head = (next(wrapped),) except StopIteration: @@ -1387,10 +1421,10 @@ def wrapper(): # --------------------------------------------------------------------- -def _add_headers_to_environ(env, headers): +def _add_headers_to_environ(env: Dict[str, Any], headers: Optional[HeaderArg]) -> None: if headers: try: - items = headers.items() + items = headers.items() # type: ignore[union-attr] except AttributeError: items = headers @@ -1413,14 +1447,21 @@ def _add_headers_to_environ(env, headers): def _add_headers_to_scope( - scope, headers, content_length, host, port, scheme, http_version, cookies -): + scope: dict[str, Any], + headers: Optional[HeaderArg], + content_length: Optional[int], + host: str, + port: int, + scheme: Optional[str], + http_version: str, + cookies: Optional[_CookieType], +) -> None: found_ua = False - prepared_headers = [] + prepared_headers: List[Iterable[bytes]] = [] if headers: try: - items = headers.items() + items = headers.items() # type: ignore[union-attr] except AttributeError: items = headers @@ -1463,7 +1504,7 @@ def _add_headers_to_scope( scope['headers'] = iter(prepared_headers) -def _fixup_http_version(http_version) -> str: +def _fixup_http_version(http_version: str) -> str: if http_version not in ('2', '2.0', '1.1', '1.0', '1'): raise ValueError('Invalid http_version specified: ' + http_version) @@ -1477,7 +1518,7 @@ def _fixup_http_version(http_version) -> str: return http_version -def _make_cookie_values(cookies: Dict) -> str: +def _make_cookie_values(cookies: _CookieType) -> str: return '; '.join( [ '{}={}'.format(key, cookie.value if hasattr(cookie, 'value') else cookie) diff --git a/falcon/testing/resource.py b/falcon/testing/resource.py index 32b8833f8..9896a9da6 100644 --- a/falcon/testing/resource.py +++ b/falcon/testing/resource.py @@ -182,7 +182,7 @@ def __init__( body: typing.Optional[str] = None, json: typing.Optional[dict[str, str]] = None, headers: typing.Optional[HeaderArg] = None, - ): + ) -> None: self._default_status = status self._default_headers = headers @@ -209,7 +209,7 @@ def __init__( self.captured_req_body: typing.Optional[bytes] = None @property - def called(self): + def called(self) -> bool: return self.captured_req is not None @falcon.before(capture_responder_args) diff --git a/falcon/testing/srmock.py b/falcon/testing/srmock.py index 389b85744..bae2613c9 100644 --- a/falcon/testing/srmock.py +++ b/falcon/testing/srmock.py @@ -18,7 +18,13 @@ used, along with a mock environ dict, to simulate a WSGI request. """ +from __future__ import annotations + +from typing import Optional + from falcon import util +from falcon.typing import HeaderList +from falcon.typing import ResponseStatus class StartResponseMock: @@ -35,13 +41,18 @@ class StartResponseMock: """ - def __init__(self): + def __init__(self) -> None: self._called = 0 - self.status = None - self.headers = None - self.exc_info = None + self.status: Optional[ResponseStatus] = None + self.headers: Optional[HeaderList] = None + self.exc_info: Optional[Exception] = None - def __call__(self, status, headers, exc_info=None): + def __call__( + self, + status: ResponseStatus, + headers: HeaderList, + exc_info: Optional[Exception] = None, + ) -> None: """Implement the PEP-3333 `start_response` protocol.""" self._called += 1 @@ -56,5 +67,5 @@ def __call__(self, status, headers, exc_info=None): self.exc_info = exc_info @property - def call_count(self): + def call_count(self) -> int: return self._called From 13dacd621be0dd66110fef3e18237fee8979cabf Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 23 Sep 2024 08:50:56 +0200 Subject: [PATCH 5/8] typing: type testing client --- falcon/testing/client.py | 299 ++++++++++++++++++++++-------------- falcon/testing/helpers.py | 19 +-- falcon/testing/srmock.py | 13 +- falcon/testing/test_case.py | 2 +- falcon/typing.py | 4 +- falcon/util/misc.py | 6 +- 6 files changed, 204 insertions(+), 139 deletions(-) diff --git a/falcon/testing/client.py b/falcon/testing/client.py index 75a4f2882..3774e7bf6 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -22,19 +22,39 @@ import asyncio import datetime as dt +from http.cookies import Morsel import inspect import json as json_module import time -from typing import Dict, Optional, Sequence, Union +from typing import ( + Any, + Awaitable, + Callable, + cast, + Coroutine, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + TextIO, + Tuple, + TypeVar, + Union, +) import warnings import wsgiref.validate +from falcon.asgi_spec import AsgiEvent from falcon.asgi_spec import ScopeType from falcon.constants import COMBINED_METHODS from falcon.constants import MEDIA_JSON from falcon.errors import CompatibilityError from falcon.testing import helpers from falcon.testing.srmock import StartResponseMock +from falcon.typing import CookieArg +from falcon.typing import HeaderList +from falcon.typing import Headers from falcon.util import async_to_sync from falcon.util import CaseInsensitiveDict from falcon.util import code_to_http_status @@ -50,18 +70,21 @@ 0, ) +_T = TypeVar('_T', bound=Callable[..., Any]) -def _simulate_method_alias(method, version_added='3.1', replace_name=None): - return_type = inspect.signature(method).return_annotation - def alias(client, *args, **kwargs) -> return_type: +def _simulate_method_alias( + method: _T, version_added: str = '3.1', replace_name: Optional[str] = None +) -> _T: + def alias(client: Any, *args: Any, **kwargs: Any) -> Any: return method(client, *args, **kwargs) - async def async_alias(client, *args, **kwargs) -> return_type: + async def async_alias(client: Any, *args: Any, **kwargs: Any) -> Any: return await method(client, *args, **kwargs) alias = async_alias if inspect.iscoroutinefunction(method) else alias + assert method.__doc__ alias.__doc__ = method.__doc__ + '\n .. versionadded:: {}\n'.format( version_added ) @@ -71,7 +94,7 @@ async def async_alias(client, *args, **kwargs) -> return_type: else: alias.__name__ = method.__name__.partition('simulate_')[-1] - return alias + return cast(_T, alias) class Cookie: @@ -103,7 +126,16 @@ class Cookie: ``Partitioned`` flag set. """ - def __init__(self, morsel): + _expires: Optional[str] + _path: str + _domain: str + _max_age: Optional[str] + _secure: Optional[str] + _httponly: Optional[str] + _samesite: Optional[str] + _partitioned: Optional[str] + + def __init__(self, morsel: Morsel) -> None: self._name = morsel.key self._value = morsel.value @@ -130,38 +162,38 @@ def value(self) -> str: @property def expires(self) -> Optional[dt.datetime]: - if self._expires: # type: ignore[attr-defined] - return http_date_to_dt(self._expires, obs_date=True) # type: ignore[attr-defined] # noqa E501 + if self._expires: + return http_date_to_dt(self._expires, obs_date=True) return None @property def path(self) -> str: - return self._path # type: ignore[attr-defined] + return self._path @property def domain(self) -> str: - return self._domain # type: ignore[attr-defined] + return self._domain @property def max_age(self) -> Optional[int]: - return int(self._max_age) if self._max_age else None # type: ignore[attr-defined] # noqa E501 + return int(self._max_age) if self._max_age else None @property def secure(self) -> bool: - return bool(self._secure) # type: ignore[attr-defined] + return bool(self._secure) @property def http_only(self) -> bool: - return bool(self._httponly) # type: ignore[attr-defined] + return bool(self._httponly) @property def same_site(self) -> Optional[str]: - return self._samesite if self._samesite else None # type: ignore[attr-defined] + return self._samesite if self._samesite else None @property def partitioned(self) -> bool: - return bool(self._partitioned) # type: ignore[attr-defined] + return bool(self._partitioned) class _ResultBase: @@ -201,7 +233,7 @@ class _ResultBase: if the encoding can not be determined. """ - def __init__(self, status, headers): + def __init__(self, status: str, headers: HeaderList) -> None: self._status = status self._status_code = int(status[:3]) self._headers = CaseInsensitiveDict(headers) @@ -241,7 +273,7 @@ def cookies(self) -> Dict[str, Cookie]: return self._cookies @property - def encoding(self) -> str: + def encoding(self) -> Optional[str]: return self._encoding @@ -254,7 +286,7 @@ class ResultBodyStream: collected. """ - def __init__(self, chunks: Sequence[bytes]): + def __init__(self, chunks: Sequence[bytes]) -> None: self._chunks = chunks self._chunk_pos = 0 @@ -322,10 +354,12 @@ class Result(_ResultBase): if the response is not valid JSON. """ - def __init__(self, iterable, status, headers): + def __init__( + self, iterable: Iterable[bytes], status: str, headers: HeaderList + ) -> None: super().__init__(status, headers) - self._text = None + self._text: Optional[str] = None self._content = b''.join(iterable) @property @@ -348,13 +382,13 @@ def text(self) -> str: return self._text @property - def json(self) -> Optional[Union[dict, list, str, int, float, bool]]: + def json(self) -> Any: if not self.text: return None return json_module.loads(self.text) - def __repr__(self): + def __repr__(self) -> str: content_type = self.headers.get('Content-Type', '') if len(self.content) > 40: @@ -409,7 +443,14 @@ class StreamedResult(_ResultBase): stream (ResultStream): Raw response body, as a byte stream. """ - def __init__(self, body_chunks, status, headers, task, req_event_emitter): + def __init__( + self, + body_chunks: Sequence[bytes], + status: str, + headers: HeaderList, + task: asyncio.Task, + req_event_emitter: helpers.ASGIRequestEventEmitter, + ): super().__init__(status, headers) self._task = task @@ -420,7 +461,7 @@ def __init__(self, body_chunks, status, headers, task, req_event_emitter): def stream(self) -> ResultBodyStream: return self._stream - async def finalize(self): + async def finalize(self) -> None: """Finalize the encapsulated simulated request. This method causes the request event emitter to begin emitting @@ -436,28 +477,28 @@ async def finalize(self): # appears to be "hanging", which might indicates that the app is # not handling the reception of events correctly. def simulate_request( - app, - method='GET', - path='/', - query_string=None, - headers=None, - content_type=None, - body=None, - json=None, - file_wrapper=None, - wsgierrors=None, - params=None, - params_csv=False, - protocol='http', - host=helpers.DEFAULT_HOST, - remote_addr=None, - extras=None, - http_version='1.1', - port=None, - root_path=None, - cookies=None, - asgi_chunk_size=4096, - asgi_disconnect_ttl=300, + app: Callable[..., Any], # accept any asgi/wsgi app + method: str = 'GET', + path: str = '/', + query_string: Optional[str] = None, + headers: Optional[Headers] = None, + content_type: Optional[str] = None, + body: Optional[Union[str, bytes]] = None, + json: Optional[Any] = None, + file_wrapper: Optional[Callable[..., Any]] = None, + wsgierrors: Optional[TextIO] = None, + params: Optional[Mapping[str, Any]] = None, + params_csv: bool = False, + protocol: str = 'http', + host: str = helpers.DEFAULT_HOST, + remote_addr: Optional[str] = None, + extras: Optional[Mapping[str, Any]] = None, + http_version: str = '1.1', + port: Optional[int] = None, + root_path: Optional[str] = None, + cookies: Optional[CookieArg] = None, + asgi_chunk_size: int = 4096, + asgi_disconnect_ttl: int = 300, ) -> _ResultBase: """Simulate a request to a WSGI or ASGI application. @@ -613,7 +654,7 @@ def simulate_request( path=path, query_string=(query_string or ''), headers=headers, - body=body, + body=body or b'', file_wrapper=file_wrapper, host=host, remote_addr=remote_addr, @@ -641,6 +682,7 @@ def simulate_request( iterable = validator(env, srmock) + assert srmock.status is not None and srmock.headers is not None return Result(helpers.closed_wsgi_iterable(iterable), srmock.status, srmock.headers) @@ -649,33 +691,33 @@ def simulate_request( # appears to be "hanging", which might indicates that the app is # not handling the reception of events correctly. async def _simulate_request_asgi( - app, - method='GET', - path='/', - query_string=None, - headers=None, - content_type=None, - body=None, - json=None, - params=None, - params_csv=True, - protocol='http', - host=helpers.DEFAULT_HOST, - remote_addr=None, - extras=None, - http_version='1.1', - port=None, - root_path=None, - asgi_chunk_size=4096, - asgi_disconnect_ttl=300, - cookies=None, + app: Callable[..., Coroutine[Any, Any, Any]], # accept any asgi app + method: str = 'GET', + path: str = '/', + query_string: Optional[str] = None, + headers: Optional[Headers] = None, + content_type: Optional[str] = None, + body: Optional[Union[str, bytes]] = None, + json: Optional[Any] = None, + params: Optional[Mapping[str, Any]] = None, + params_csv: bool = False, + protocol: str = 'http', + host: str = helpers.DEFAULT_HOST, + remote_addr: Optional[str] = None, + extras: Optional[Mapping[str, Any]] = None, + http_version: str = '1.1', + port: Optional[int] = None, + root_path: Optional[str] = None, + asgi_chunk_size: int = 4096, + asgi_disconnect_ttl: int = 300, + cookies: Optional[CookieArg] = None, # NOTE(kgriffs): These are undocumented because they are only # meant to be used internally by the framework (i.e., they are # not part of the public interface.) In case we ever expose # simulate_request_asgi() as part of the public interface, we # don't want these kwargs to be documented. - _one_shot=True, - _stream_result=False, + _one_shot: bool = True, + _stream_result: bool = False, ) -> _ResultBase: """Simulate a request to an ASGI application. @@ -818,7 +860,7 @@ async def _simulate_request_asgi( # --------------------------------------------------------------------- if asgi_disconnect_ttl == 0: # Special case - disconnect_at = 0 + disconnect_at = 0.0 else: disconnect_at = time.time() + max(0, asgi_disconnect_ttl) @@ -874,7 +916,7 @@ async def _simulate_request_asgi( lifespan_event_collector = helpers.ASGIResponseEventCollector() # --------------------------------------------------------------------- - async def conductor(): + async def conductor() -> None: # NOTE(kgriffs): We assume this is a Falcon ASGI app, which supports # the lifespan protocol and thus we do not need to catch # exceptions that would signify no lifespan protocol support. @@ -998,7 +1040,11 @@ async def get_events_sse(): """ - def __init__(self, app, headers=None): + def __init__( + self, + app: Callable[..., Coroutine[Any, Any, Any]], + headers: Optional[Headers] = None, + ): if not _is_asgi_app(app): raise CompatibilityError('ASGIConductor may only be used with an ASGI app') @@ -1007,9 +1053,9 @@ def __init__(self, app, headers=None): self._shutting_down = asyncio.Condition() self._lifespan_event_collector = helpers.ASGIResponseEventCollector() - self._lifespan_task = None + self._lifespan_task: Optional[asyncio.Task] = None - async def __aenter__(self): + async def __aenter__(self) -> ASGIConductor: lifespan_scope = { 'type': ScopeType.LIFESPAN, 'asgi': { @@ -1033,7 +1079,7 @@ async def __aenter__(self): return self - async def __aexit__(self, ex_type, ex, tb): + async def __aexit__(self, ex_type: Any, ex: Any, tb: Any) -> bool: if ex_type: return False @@ -1043,18 +1089,21 @@ async def __aexit__(self, ex_type, ex, tb): self._shutting_down.notify() await _wait_for_shutdown(self._lifespan_event_collector.events) + assert self._lifespan_task is not None await self._lifespan_task return True - async def simulate_get(self, path='/', **kwargs) -> _ResultBase: + async def simulate_get(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a GET request to an ASGI application. (See also: :meth:`falcon.testing.simulate_get`) """ return await self.simulate_request('GET', path, **kwargs) - def simulate_get_stream(self, path='/', **kwargs): + def simulate_get_stream( + self, path: str = '/', **kwargs: Any + ) -> _AsyncContextManager: """Simulate a GET request to an ASGI application with a streamed response. (See also: :meth:`falcon.testing.simulate_get` for a list of @@ -1088,7 +1137,7 @@ def simulate_get_stream(self, path='/', **kwargs): return _AsyncContextManager(self.simulate_request('GET', path, **kwargs)) - def simulate_ws(self, path='/', **kwargs): + def simulate_ws(self, path: str = '/', **kwargs: Any) -> _WSContextManager: """Simulate a WebSocket connection to an ASGI application. All keyword arguments are passed through to @@ -1116,49 +1165,49 @@ def simulate_ws(self, path='/', **kwargs): return _WSContextManager(ws, task_req) - async def simulate_head(self, path='/', **kwargs) -> _ResultBase: + async def simulate_head(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a HEAD request to an ASGI application. (See also: :meth:`falcon.testing.simulate_head`) """ return await self.simulate_request('HEAD', path, **kwargs) - async def simulate_post(self, path='/', **kwargs) -> _ResultBase: + async def simulate_post(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a POST request to an ASGI application. (See also: :meth:`falcon.testing.simulate_post`) """ return await self.simulate_request('POST', path, **kwargs) - async def simulate_put(self, path='/', **kwargs) -> _ResultBase: + async def simulate_put(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a PUT request to an ASGI application. (See also: :meth:`falcon.testing.simulate_put`) """ return await self.simulate_request('PUT', path, **kwargs) - async def simulate_options(self, path='/', **kwargs) -> _ResultBase: + async def simulate_options(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate an OPTIONS request to an ASGI application. (See also: :meth:`falcon.testing.simulate_options`) """ return await self.simulate_request('OPTIONS', path, **kwargs) - async def simulate_patch(self, path='/', **kwargs) -> _ResultBase: + async def simulate_patch(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a PATCH request to an ASGI application. (See also: :meth:`falcon.testing.simulate_patch`) """ return await self.simulate_request('PATCH', path, **kwargs) - async def simulate_delete(self, path='/', **kwargs) -> _ResultBase: + async def simulate_delete(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a DELETE request to an ASGI application. (See also: :meth:`falcon.testing.simulate_delete`) """ return await self.simulate_request('DELETE', path, **kwargs) - async def simulate_request(self, *args, **kwargs) -> _ResultBase: + async def simulate_request(self, *args: Any, **kwargs: Any) -> _ResultBase: """Simulate a request to an ASGI application. Wraps :meth:`falcon.testing.simulate_request` to perform a @@ -1194,7 +1243,7 @@ async def simulate_request(self, *args, **kwargs) -> _ResultBase: websocket = _simulate_method_alias(simulate_ws, replace_name='websocket') -def simulate_get(app, path, **kwargs) -> _ResultBase: +def simulate_get(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a GET request to a WSGI or ASGI application. Equivalent to:: @@ -1297,7 +1346,7 @@ def simulate_get(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'GET', path, **kwargs) -def simulate_head(app, path, **kwargs) -> _ResultBase: +def simulate_head(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a HEAD request to a WSGI or ASGI application. Equivalent to:: @@ -1394,7 +1443,7 @@ def simulate_head(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'HEAD', path, **kwargs) -def simulate_post(app, path, **kwargs) -> _ResultBase: +def simulate_post(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a POST request to a WSGI or ASGI application. Equivalent to:: @@ -1505,7 +1554,7 @@ def simulate_post(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'POST', path, **kwargs) -def simulate_put(app, path, **kwargs) -> _ResultBase: +def simulate_put(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a PUT request to a WSGI or ASGI application. Equivalent to:: @@ -1616,7 +1665,7 @@ def simulate_put(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'PUT', path, **kwargs) -def simulate_options(app, path, **kwargs) -> _ResultBase: +def simulate_options(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate an OPTIONS request to a WSGI or ASGI application. Equivalent to:: @@ -1705,7 +1754,7 @@ def simulate_options(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'OPTIONS', path, **kwargs) -def simulate_patch(app, path, **kwargs) -> _ResultBase: +def simulate_patch(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a PATCH request to a WSGI or ASGI application. Equivalent to:: @@ -1811,7 +1860,7 @@ def simulate_patch(app, path, **kwargs) -> _ResultBase: return simulate_request(app, 'PATCH', path, **kwargs) -def simulate_delete(app, path, **kwargs) -> _ResultBase: +def simulate_delete(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: """Simulate a DELETE request to a WSGI or ASGI application. Equivalent to:: @@ -1982,12 +2031,16 @@ class TestClient: # NOTE(aryaniyaps): Prevent pytest from collecting tests on the class. __test__ = False - def __init__(self, app, headers=None): + def __init__( + self, + app: Callable[..., Any], # accept any asgi/wsgi app + headers: Optional[Headers] = None, + ) -> None: self.app = app self._default_headers = headers - self._conductor = None + self._conductor: Optional[ASGIConductor] = None - async def __aenter__(self): + async def __aenter__(self) -> ASGIConductor: if not _is_asgi_app(self.app): raise CompatibilityError( 'a conductor context manager may only be used with a Falcon ASGI app' @@ -2002,7 +2055,8 @@ async def __aenter__(self): return self._conductor - async def __aexit__(self, ex_type, ex, tb): + async def __aexit__(self, ex_type: Any, ex: Any, tb: Any) -> bool: + assert self._conductor is not None result = await self._conductor.__aexit__(ex_type, ex, tb) # NOTE(kgriffs): Reset to allow this instance of TestClient to be @@ -2011,56 +2065,56 @@ async def __aexit__(self, ex_type, ex, tb): return result - def simulate_get(self, path='/', **kwargs) -> _ResultBase: + def simulate_get(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a GET request to a WSGI application. (See also: :meth:`falcon.testing.simulate_get`) """ return self.simulate_request('GET', path, **kwargs) - def simulate_head(self, path='/', **kwargs) -> _ResultBase: + def simulate_head(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a HEAD request to a WSGI application. (See also: :meth:`falcon.testing.simulate_head`) """ return self.simulate_request('HEAD', path, **kwargs) - def simulate_post(self, path='/', **kwargs) -> _ResultBase: + def simulate_post(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a POST request to a WSGI application. (See also: :meth:`falcon.testing.simulate_post`) """ return self.simulate_request('POST', path, **kwargs) - def simulate_put(self, path='/', **kwargs) -> _ResultBase: + def simulate_put(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a PUT request to a WSGI application. (See also: :meth:`falcon.testing.simulate_put`) """ return self.simulate_request('PUT', path, **kwargs) - def simulate_options(self, path='/', **kwargs) -> _ResultBase: + def simulate_options(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate an OPTIONS request to a WSGI application. (See also: :meth:`falcon.testing.simulate_options`) """ return self.simulate_request('OPTIONS', path, **kwargs) - def simulate_patch(self, path='/', **kwargs) -> _ResultBase: + def simulate_patch(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a PATCH request to a WSGI application. (See also: :meth:`falcon.testing.simulate_patch`) """ return self.simulate_request('PATCH', path, **kwargs) - def simulate_delete(self, path='/', **kwargs) -> _ResultBase: + def simulate_delete(self, path: str = '/', **kwargs: Any) -> _ResultBase: """Simulate a DELETE request to a WSGI application. (See also: :meth:`falcon.testing.simulate_delete`) """ return self.simulate_request('DELETE', path, **kwargs) - def simulate_request(self, *args, **kwargs) -> _ResultBase: + def simulate_request(self, *args: Any, **kwargs: Any) -> _ResultBase: """Simulate a request to a WSGI application. Wraps :meth:`falcon.testing.simulate_request` to perform a @@ -2097,25 +2151,28 @@ def simulate_request(self, *args, **kwargs) -> _ResultBase: class _AsyncContextManager: - def __init__(self, coro): + def __init__(self, coro: Awaitable[StreamedResult]): self._coro = coro - self._obj = None + self._obj: Optional[StreamedResult] = None - async def __aenter__(self): + async def __aenter__(self) -> StreamedResult: self._obj = await self._coro return self._obj - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + assert self._obj is not None await self._obj.finalize() self._obj = None class _WSContextManager: - def __init__(self, ws, task_req): + def __init__( + self, ws: helpers.ASGIWebSocketSimulator, task_req: asyncio.Task + ) -> None: self._ws = ws self._task_req = task_req - async def __aenter__(self): + async def __aenter__(self) -> helpers.ASGIWebSocketSimulator: ready_waiter = asyncio.create_task(self._ws.wait_ready()) # NOTE(kgriffs): Wait on both so that in the case that the request @@ -2140,14 +2197,22 @@ async def __aenter__(self): return self._ws - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: await self._ws.close() await self._task_req def _prepare_sim_args( - path, query_string, params, params_csv, content_type, headers, body, json, extras -): + path: str, + query_string: Optional[str], + params: Optional[Mapping[str, Any]], + params_csv: bool, + content_type: Optional[str], + headers: Optional[Headers], + body: Optional[Union[str, bytes]], + json: Optional[Any], + extras: Optional[Mapping[str, Any]], +) -> Tuple[str, str, Optional[Headers], Optional[Union[str, bytes]], Mapping[str, Any]]: if not path.startswith('/'): raise ValueError("path must start with '/'") @@ -2183,7 +2248,7 @@ def _prepare_sim_args( return path, query_string, headers, body, extras -def _is_asgi_app(app): +def _is_asgi_app(app: Callable[..., Any]) -> bool: app_args = inspect.getfullargspec(app).args num_app_args = len(app_args) @@ -2198,7 +2263,7 @@ def _is_asgi_app(app): return is_asgi -async def _wait_for_startup(events): +async def _wait_for_startup(events: Iterable[AsgiEvent]) -> None: # NOTE(kgriffs): This is covered, but our gate for some reason doesn't # understand `while True`. while True: # pragma: nocover @@ -2215,7 +2280,7 @@ async def _wait_for_startup(events): await asyncio.sleep(0) -async def _wait_for_shutdown(events): +async def _wait_for_shutdown(events: Iterable[AsgiEvent]) -> None: # NOTE(kgriffs): This is covered, but our gate for some reason doesn't # understand `while True`. while True: # pragma: nocover diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index c44208ec0..fa00920b2 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -30,7 +30,6 @@ import contextlib from enum import auto from enum import Enum -from http.cookiejar import Cookie import io import itertools import json @@ -64,6 +63,7 @@ from falcon.asgi_spec import WSCloseCode from falcon.constants import SINGLETON_HEADERS import falcon.request +from falcon.typing import CookieArg from falcon.typing import HeaderArg from falcon.typing import HeaderList from falcon.typing import ResponseStatus @@ -839,7 +839,7 @@ def _create_checked_disconnect(self) -> AsgiEvent: # get_encoding_from_headers() is Copyright 2016 Kenneth Reitz, and is # used here under the terms of the Apache License, Version 2.0. -def get_encoding_from_headers(headers: Dict[str, str]) -> Optional[str]: +def get_encoding_from_headers(headers: Mapping[str, str]) -> Optional[str]: """Return encoding from given HTTP Header Dict. Args: @@ -900,9 +900,6 @@ def rand_string(min: int, max: int) -> str: return ''.join([chr(int_gen(ord(' '), ord('~'))) for __ in range(string_length)]) -_CookieType = Mapping[str, Union[str, Cookie]] - - def create_scope( path: str = '/', query_string: str = '', @@ -916,7 +913,7 @@ def create_scope( root_path: Optional[str] = None, content_length: Optional[int] = None, include_server: bool = True, - cookies: Optional[_CookieType] = None, + cookies: Optional[CookieArg] = None, ) -> Dict[str, Any]: """Create a mock ASGI scope ``dict`` for simulating HTTP requests. @@ -1133,13 +1130,13 @@ def create_environ( port: Optional[int] = None, headers: Optional[HeaderArg] = None, app: Optional[str] = None, - body: str = '', + body: Union[str, bytes] = b'', method: str = 'GET', - wsgierrors: Optional[io.StringIO] = None, + wsgierrors: Optional[TextIO] = None, file_wrapper: Optional[Callable[..., Any]] = None, remote_addr: Optional[str] = None, root_path: Optional[str] = None, - cookies: Optional[_CookieType] = None, + cookies: Optional[CookieArg] = None, ) -> Dict[str, Any]: """Create a mock PEP-3333 environ ``dict`` for simulating WSGI requests. @@ -1454,7 +1451,7 @@ def _add_headers_to_scope( port: int, scheme: Optional[str], http_version: str, - cookies: Optional[_CookieType], + cookies: Optional[CookieArg], ) -> None: found_ua = False prepared_headers: List[Iterable[bytes]] = [] @@ -1518,7 +1515,7 @@ def _fixup_http_version(http_version: str) -> str: return http_version -def _make_cookie_values(cookies: _CookieType) -> str: +def _make_cookie_values(cookies: CookieArg) -> str: return '; '.join( [ '{}={}'.format(key, cookie.value if hasattr(cookie, 'value') else cookie) diff --git a/falcon/testing/srmock.py b/falcon/testing/srmock.py index bae2613c9..97decb90d 100644 --- a/falcon/testing/srmock.py +++ b/falcon/testing/srmock.py @@ -20,11 +20,10 @@ from __future__ import annotations -from typing import Optional +from typing import Any, Optional from falcon import util from falcon.typing import HeaderList -from falcon.typing import ResponseStatus class StartResponseMock: @@ -43,16 +42,16 @@ class StartResponseMock: def __init__(self) -> None: self._called = 0 - self.status: Optional[ResponseStatus] = None + self.status: Optional[str] = None self.headers: Optional[HeaderList] = None - self.exc_info: Optional[Exception] = None + self.exc_info: Optional[Any] = None def __call__( self, - status: ResponseStatus, + status: str, headers: HeaderList, - exc_info: Optional[Exception] = None, - ) -> None: + exc_info: Optional[Any] = None, + ) -> Any: """Implement the PEP-3333 `start_response` protocol.""" self._called += 1 diff --git a/falcon/testing/test_case.py b/falcon/testing/test_case.py index 1cb95328c..368ce0978 100644 --- a/falcon/testing/test_case.py +++ b/falcon/testing/test_case.py @@ -78,7 +78,7 @@ def test_get_message(self): # NOTE(vytas): Here we have to restore __test__ to allow collecting tests! __test__ = True - def setUp(self): + def setUp(self) -> None: super(TestCase, self).setUp() app = falcon.App() diff --git a/falcon/typing.py b/falcon/typing.py index 7bcdfa185..881e92766 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -18,6 +18,7 @@ from enum import auto from enum import Enum import http +from http.cookiejar import Cookie import sys from typing import ( Any, @@ -27,6 +28,7 @@ Dict, List, Literal, + Mapping, Optional, Pattern, Protocol, @@ -66,7 +68,7 @@ class _Missing(Enum): MissingOr = Union[Literal[_Missing.MISSING], _T] Link = Dict[str, str] - +CookieArg = Mapping[str, Union[str, Cookie]] # Error handlers ErrorHandler = Callable[['Request', 'Response', BaseException, Dict[str, Any]], None] diff --git a/falcon/util/misc.py b/falcon/util/misc.py index 1fbe09ef3..7ee32d983 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -30,7 +30,7 @@ import http import inspect import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import unicodedata from falcon import status_codes @@ -216,7 +216,9 @@ def http_date_to_dt(http_date: str, obs_date: bool = False) -> datetime.datetime def to_query_str( - params: Dict[str, Any], comma_delimited_lists: bool = True, prefix: bool = True + params: Optional[Mapping[str, Any]], + comma_delimited_lists: bool = True, + prefix: bool = True, ) -> str: """Convert a dictionary of parameters to a query string. From b1959cc9fde60568065ddc20c0ca1bc84227cbf0 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 23 Sep 2024 09:05:02 +0200 Subject: [PATCH 6/8] typing: properly type return value of test client simulate* methods. fixes: #2207 --- falcon/testing/client.py | 122 ++++++++++++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 28 deletions(-) diff --git a/falcon/testing/client.py b/falcon/testing/client.py index 3774e7bf6..d1d5a6d5b 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -34,8 +34,10 @@ Coroutine, Dict, Iterable, + Literal, Mapping, Optional, + overload, Sequence, TextIO, Tuple, @@ -499,7 +501,7 @@ def simulate_request( cookies: Optional[CookieArg] = None, asgi_chunk_size: int = 4096, asgi_disconnect_ttl: int = 300, -) -> _ResultBase: +) -> Result: """Simulate a request to a WSGI or ASGI application. Performs a request against a WSGI or ASGI application. In the case of @@ -612,7 +614,7 @@ def simulate_request( """ if _is_asgi_app(app): - return async_to_sync( + return async_to_sync( # type: ignore[return-value] _simulate_request_asgi, app, method=method, @@ -686,6 +688,60 @@ def simulate_request( return Result(helpers.closed_wsgi_iterable(iterable), srmock.status, srmock.headers) +@overload +async def _simulate_request_asgi( + app: Callable[..., Coroutine[Any, Any, Any]], + method: str = ..., + path: str = ..., + query_string: Optional[str] = ..., + headers: Optional[Headers] = ..., + content_type: Optional[str] = ..., + body: Optional[Union[str, bytes]] = ..., + json: Optional[Any] = ..., + params: Optional[Mapping[str, Any]] = ..., + params_csv: bool = ..., + protocol: str = ..., + host: str = ..., + remote_addr: Optional[str] = ..., + extras: Optional[Mapping[str, Any]] = ..., + http_version: str = ..., + port: Optional[int] = ..., + root_path: Optional[str] = ..., + asgi_chunk_size: int = ..., + asgi_disconnect_ttl: int = ..., + cookies: Optional[CookieArg] = ..., + _one_shot: Literal[False] = ..., + _stream_result: Literal[True] = ..., +) -> StreamedResult: ... + + +@overload +async def _simulate_request_asgi( + app: Callable[..., Coroutine[Any, Any, Any]], + method: str = ..., + path: str = ..., + query_string: Optional[str] = ..., + headers: Optional[Headers] = ..., + content_type: Optional[str] = ..., + body: Optional[Union[str, bytes]] = ..., + json: Optional[Any] = ..., + params: Optional[Mapping[str, Any]] = ..., + params_csv: bool = ..., + protocol: str = ..., + host: str = ..., + remote_addr: Optional[str] = ..., + extras: Optional[Mapping[str, Any]] = ..., + http_version: str = ..., + port: Optional[int] = ..., + root_path: Optional[str] = ..., + asgi_chunk_size: int = ..., + asgi_disconnect_ttl: int = ..., + cookies: Optional[CookieArg] = ..., + _one_shot: Literal[True] = ..., + _stream_result: bool = ..., +) -> Result: ... + + # NOTE(kgriffs): The default of asgi_disconnect_ttl was chosen to be # relatively long (5 minutes) to help testers notice when something # appears to be "hanging", which might indicates that the app is @@ -718,7 +774,7 @@ async def _simulate_request_asgi( # don't want these kwargs to be documented. _one_shot: bool = True, _stream_result: bool = False, -) -> _ResultBase: +) -> Union[Result, StreamedResult]: """Simulate a request to an ASGI application. Keyword Args: @@ -1094,7 +1150,7 @@ async def __aexit__(self, ex_type: Any, ex: Any, tb: Any) -> bool: return True - async def simulate_get(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_get(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a GET request to an ASGI application. (See also: :meth:`falcon.testing.simulate_get`) @@ -1165,53 +1221,63 @@ def simulate_ws(self, path: str = '/', **kwargs: Any) -> _WSContextManager: return _WSContextManager(ws, task_req) - async def simulate_head(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_head(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a HEAD request to an ASGI application. (See also: :meth:`falcon.testing.simulate_head`) """ return await self.simulate_request('HEAD', path, **kwargs) - async def simulate_post(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_post(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a POST request to an ASGI application. (See also: :meth:`falcon.testing.simulate_post`) """ return await self.simulate_request('POST', path, **kwargs) - async def simulate_put(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_put(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PUT request to an ASGI application. (See also: :meth:`falcon.testing.simulate_put`) """ return await self.simulate_request('PUT', path, **kwargs) - async def simulate_options(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_options(self, path: str = '/', **kwargs: Any) -> Result: """Simulate an OPTIONS request to an ASGI application. (See also: :meth:`falcon.testing.simulate_options`) """ return await self.simulate_request('OPTIONS', path, **kwargs) - async def simulate_patch(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_patch(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PATCH request to an ASGI application. (See also: :meth:`falcon.testing.simulate_patch`) """ return await self.simulate_request('PATCH', path, **kwargs) - async def simulate_delete(self, path: str = '/', **kwargs: Any) -> _ResultBase: + async def simulate_delete(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a DELETE request to an ASGI application. (See also: :meth:`falcon.testing.simulate_delete`) """ return await self.simulate_request('DELETE', path, **kwargs) - async def simulate_request(self, *args: Any, **kwargs: Any) -> _ResultBase: + @overload + async def simulate_request( + self, *args: Any, _stream_result: Literal[True], **kwargs: Any + ) -> StreamedResult: ... + + @overload + async def simulate_request(self, *args: Any, **kwargs: Any) -> Result: ... + + async def simulate_request( + self, *args: Any, **kwargs: Any + ) -> Union[Result, StreamedResult]: """Simulate a request to an ASGI application. - Wraps :meth:`falcon.testing.simulate_request` to perform a - WSGI request directly against ``self.app``. Equivalent to:: + Wraps :meth:`falcon.testing.simulate_request` to perform an + ASGI request directly against ``self.app``. Equivalent to:: falcon.testing.simulate_request(self.app, *args, **kwargs) """ @@ -1243,7 +1309,7 @@ async def simulate_request(self, *args: Any, **kwargs: Any) -> _ResultBase: websocket = _simulate_method_alias(simulate_ws, replace_name='websocket') -def simulate_get(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_get(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a GET request to a WSGI or ASGI application. Equivalent to:: @@ -1346,7 +1412,7 @@ def simulate_get(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBa return simulate_request(app, 'GET', path, **kwargs) -def simulate_head(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_head(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a HEAD request to a WSGI or ASGI application. Equivalent to:: @@ -1443,7 +1509,7 @@ def simulate_head(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultB return simulate_request(app, 'HEAD', path, **kwargs) -def simulate_post(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_post(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a POST request to a WSGI or ASGI application. Equivalent to:: @@ -1554,7 +1620,7 @@ def simulate_post(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultB return simulate_request(app, 'POST', path, **kwargs) -def simulate_put(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_put(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a PUT request to a WSGI or ASGI application. Equivalent to:: @@ -1665,7 +1731,7 @@ def simulate_put(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBa return simulate_request(app, 'PUT', path, **kwargs) -def simulate_options(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_options(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate an OPTIONS request to a WSGI or ASGI application. Equivalent to:: @@ -1754,7 +1820,7 @@ def simulate_options(app: Callable[..., Any], path: str, **kwargs: Any) -> _Resu return simulate_request(app, 'OPTIONS', path, **kwargs) -def simulate_patch(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_patch(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a PATCH request to a WSGI or ASGI application. Equivalent to:: @@ -1860,7 +1926,7 @@ def simulate_patch(app: Callable[..., Any], path: str, **kwargs: Any) -> _Result return simulate_request(app, 'PATCH', path, **kwargs) -def simulate_delete(app: Callable[..., Any], path: str, **kwargs: Any) -> _ResultBase: +def simulate_delete(app: Callable[..., Any], path: str, **kwargs: Any) -> Result: """Simulate a DELETE request to a WSGI or ASGI application. Equivalent to:: @@ -2065,56 +2131,56 @@ async def __aexit__(self, ex_type: Any, ex: Any, tb: Any) -> bool: return result - def simulate_get(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_get(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a GET request to a WSGI application. (See also: :meth:`falcon.testing.simulate_get`) """ return self.simulate_request('GET', path, **kwargs) - def simulate_head(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_head(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a HEAD request to a WSGI application. (See also: :meth:`falcon.testing.simulate_head`) """ return self.simulate_request('HEAD', path, **kwargs) - def simulate_post(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_post(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a POST request to a WSGI application. (See also: :meth:`falcon.testing.simulate_post`) """ return self.simulate_request('POST', path, **kwargs) - def simulate_put(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_put(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PUT request to a WSGI application. (See also: :meth:`falcon.testing.simulate_put`) """ return self.simulate_request('PUT', path, **kwargs) - def simulate_options(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_options(self, path: str = '/', **kwargs: Any) -> Result: """Simulate an OPTIONS request to a WSGI application. (See also: :meth:`falcon.testing.simulate_options`) """ return self.simulate_request('OPTIONS', path, **kwargs) - def simulate_patch(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_patch(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a PATCH request to a WSGI application. (See also: :meth:`falcon.testing.simulate_patch`) """ return self.simulate_request('PATCH', path, **kwargs) - def simulate_delete(self, path: str = '/', **kwargs: Any) -> _ResultBase: + def simulate_delete(self, path: str = '/', **kwargs: Any) -> Result: """Simulate a DELETE request to a WSGI application. (See also: :meth:`falcon.testing.simulate_delete`) """ return self.simulate_request('DELETE', path, **kwargs) - def simulate_request(self, *args: Any, **kwargs: Any) -> _ResultBase: + def simulate_request(self, *args: Any, **kwargs: Any) -> Result: """Simulate a request to a WSGI application. Wraps :meth:`falcon.testing.simulate_request` to perform a From 876703c79411db9ac57d30f0e8c4fbbea1442973 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 23 Sep 2024 09:11:58 +0200 Subject: [PATCH 7/8] fix: assert is valid only after getting the data --- falcon/testing/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/falcon/testing/client.py b/falcon/testing/client.py index d1d5a6d5b..5a29b9573 100644 --- a/falcon/testing/client.py +++ b/falcon/testing/client.py @@ -684,8 +684,9 @@ def simulate_request( iterable = validator(env, srmock) + data = helpers.closed_wsgi_iterable(iterable) assert srmock.status is not None and srmock.headers is not None - return Result(helpers.closed_wsgi_iterable(iterable), srmock.status, srmock.headers) + return Result(data, srmock.status, srmock.headers) @overload From 6b309d341eecf769e34ea5be3004d2d580bb8c57 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 23 Sep 2024 19:20:40 +0200 Subject: [PATCH 8/8] chore: ignore fa errors in tutorial --- examples/ws_tutorial/ws_tutorial/app.py | 2 -- pyproject.toml | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/ws_tutorial/ws_tutorial/app.py b/examples/ws_tutorial/ws_tutorial/app.py index 9dfab7e81..054514a61 100644 --- a/examples/ws_tutorial/ws_tutorial/app.py +++ b/examples/ws_tutorial/ws_tutorial/app.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from datetime import datetime import logging import pathlib diff --git a/pyproject.toml b/pyproject.toml index a6b6cc96e..ebccd4505 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,6 +211,7 @@ exclude = ["examples", "tests"] "F403" ] "falcon/uri.py" = ["F401"] + "examples/*" = ["FA"] [tool.ruff.lint.isort] case-sensitive = false