diff --git a/pyproject.toml b/pyproject.toml index d36e57c2..ef994a01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,20 @@ markers = [ ] xfail_strict = true +[tool.mypy] +packages = [ + "responder", +] +exclude = [ +] +check_untyped_defs = true +explicit_package_bases = true +ignore_missing_imports = true +implicit_optional = true +install_types = true +namespace_packages = true +non_interactive = true + [tool.poe.tasks] check = [ @@ -98,7 +112,7 @@ lint = [ { cmd = "ruff format --check ." }, { cmd = "ruff check ." }, { cmd = "validate-pyproject pyproject.toml" }, - # { cmd = "mypy" }, + { cmd = "mypy" }, ] release = [ diff --git a/responder/api.py b/responder/api.py index 27f3cbfd..4311c8f9 100644 --- a/responder/api.py +++ b/responder/api.py @@ -221,11 +221,16 @@ async def _static_response(self, req, resp): with open(index, "r") as f: resp.html = f.read() else: - resp.status_code = status_codes.HTTP_404 + resp.status_code = status_codes.HTTP_404 # type: ignore[attr-defined] resp.text = "Not found." def redirect( - self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301 + self, + resp, + location, + *, + set_text=True, + status_code=status_codes.HTTP_301, # type: ignore[attr-defined] ): """ Redirects a given response to a given location. @@ -365,7 +370,7 @@ def serve(self, *, address=None, port=None, debug=False, **options): port = 5042 def spawn(): - uvicorn.run(self, host=address, port=port, debug=debug, **options) + uvicorn.run(self, host=address, port=port, **options) spawn() diff --git a/responder/ext/schema/__init__.py b/responder/ext/schema/__init__.py index 0d938adb..a2caa318 100644 --- a/responder/ext/schema/__init__.py +++ b/responder/ext/schema/__init__.py @@ -133,6 +133,6 @@ def docs_response(self, req, resp): resp.html = self.docs def schema_response(self, req, resp): - resp.status_code = status_codes.HTTP_200 + resp.status_code = status_codes.HTTP_200 # type: ignore[attr-defined] resp.headers["Content-Type"] = "application/x-yaml" resp.content = self.openapi diff --git a/responder/models.py b/responder/models.py index da72a705..f688f058 100644 --- a/responder/models.py +++ b/responder/models.py @@ -18,7 +18,7 @@ ) from .statics import DEFAULT_ENCODING -from .status_codes import HTTP_301 +from .status_codes import HTTP_301 # type: ignore[attr-defined] class QueryDict(dict): @@ -107,7 +107,7 @@ def __init__(self, scope, receive, api=None, formats=None): self.api = api self._content = None - headers = CaseInsensitiveDict() + headers: CaseInsensitiveDict = CaseInsensitiveDict() for key, value in self._starlette.headers.items(): headers[key] = value @@ -150,7 +150,7 @@ def cookies(self): cookies = RequestsCookieJar() cookie_header = self.headers.get("Cookie", "") - bc = SimpleCookie(cookie_header) + bc: SimpleCookie = SimpleCookie(cookie_header) for key, morsel in bc.items(): cookies[key] = morsel.value @@ -239,9 +239,20 @@ async def media(self, format: t.Union[str, t.Callable] = None): # noqa: A001, A format = "yaml" if "yaml" in self.mimetype or "" else "json" # noqa: A001 format = "form" if "form" in self.mimetype or "" else format # noqa: A001 - if format in self.formats: - return await self.formats[format](self) - return await format(self) + formatter: t.Callable + if isinstance(format, str): + try: + formatter = self.formats[format] + except KeyError as ex: + raise ValueError(f"Unable to process data in '{format}' format") from ex + + elif callable(format): + formatter = format + + else: + raise TypeError(f"Invalid 'format' argument: {format}") + + return await formatter(self) def content_setter(mimetype): @@ -275,7 +286,8 @@ class Response: def __init__(self, req, *, formats): self.req = req - self.status_code = None #: The HTTP Status Code to use for the Response. + #: The HTTP Status Code to use for the Response. + self.status_code: t.Union[int, None] = None self.content = None #: A bytes representation of the response body. self.mimetype = None self.encoding = DEFAULT_ENCODING @@ -285,7 +297,7 @@ def __init__(self, req, *, formats): self.headers = {} #: A Python dictionary of ``{key: value}``, #: representing the headers of the response. self.formats = formats - self.cookies = SimpleCookie() #: The cookies set in the Response + self.cookies: SimpleCookie = SimpleCookie() #: The cookies set in the Response self.session = ( req.session ) #: The cookie-based session data, in dict form, to add to the Response. @@ -365,16 +377,25 @@ async def __call__(self, scope, receive, send): if self.headers: headers.update(self.headers) + response_cls: t.Union[ + t.Type[StarletteResponse], t.Type[StarletteStreamingResponse] + ] if self._stream is not None: response_cls = StarletteStreamingResponse else: response_cls = StarletteResponse - response = response_cls(body, status_code=self.status_code, headers=headers) + response = response_cls(body, status_code=self.status_code_safe, headers=headers) self._prepare_cookies(response) await response(scope, receive, send) @property def ok(self): - return 200 <= self.status_code < 300 + return 200 <= self.status_code_safe < 300 + + @property + def status_code_safe(self) -> int: + if self.status_code is None: + raise RuntimeError("HTTP status code has not been defined") + return self.status_code diff --git a/responder/routes.py b/responder/routes.py index 70628083..7a9a8662 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -2,11 +2,13 @@ import inspect import re import traceback +import typing as t from collections import defaultdict from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.middleware.wsgi import WSGIMiddleware +from starlette.types import ASGIApp from starlette.websockets import WebSocket, WebSocketClose from . import status_codes @@ -121,7 +123,7 @@ async def __call__(self, scope, receive, send): views.append(view) except AttributeError as ex: if on_request is None: - raise HTTPException(status_code=status_codes.HTTP_405) from ex + raise HTTPException(status_code=status_codes.HTTP_405) from ex # type: ignore[attr-defined] else: views.append(self.endpoint) @@ -135,7 +137,7 @@ async def __call__(self, scope, receive, send): await run_in_threadpool(view, request, response, **path_params) if response.status_code is None: - response.status_code = status_codes.HTTP_200 + response.status_code = status_codes.HTTP_200 # type: ignore[attr-defined] await response(scope, receive, send) @@ -207,7 +209,7 @@ class Router: def __init__(self, routes=None, default_response=None, before_requests=None): self.routes = [] if routes is None else list(routes) # [TODO] Make its own router - self.apps = {} + self.apps: t.Dict[str, ASGIApp] = {} self.default_endpoint = ( self.default_response if default_response is None else default_response ) @@ -255,7 +257,7 @@ def add_route( def mount(self, route, app): """Mounts ASGI / WSGI applications at a given route""" - self.apps.update(route, app) + self.apps.update({route: app}) def add_event_handler(self, event_type, handler): assert event_type in ( @@ -287,14 +289,14 @@ def url_for(self, endpoint, **params): async def default_response(self, scope, receive, send): if scope["type"] == "websocket": websocket_close = WebSocketClose() - await websocket_close(receive, send) + await websocket_close(scope, receive, send) return # FIXME: Please review! request = Request(scope, receive) response = Response(request, formats=get_formats()) # noqa: F841 - raise HTTPException(status_code=status_codes.HTTP_404) + raise HTTPException(status_code=status_codes.HTTP_404) # type: ignore[attr-defined] def _resolve_route(self, scope): for route in self.routes: diff --git a/setup.py b/setup.py index 7ffe9fed..88f27162 100644 --- a/setup.py +++ b/setup.py @@ -131,7 +131,7 @@ def run(self): ], "graphql": ["graphene"], "release": ["build", "twine"], - "test": ["pytest", "pytest-cov", "pytest-mock", "flask"], + "test": ["flask", "mypy", "pytest", "pytest-cov", "pytest-mock"], }, include_package_data=True, license="Apache 2.0",