Skip to content

Commit

Permalink
Sandbox: Enable mypy type checker
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Oct 30, 2024
1 parent e4cff76 commit a5b6d36
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 22 deletions.
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ markers = [
]
xfail_strict = true

[tool.mypy]
packages = [
"responder",
]
exclude = [
]
check_untyped_defs = true
explicit_package_bases = true
ignore_missing_imports = true
implicit_optional = true
install_types = true
namespace_packages = true
non_interactive = true

[tool.poe.tasks]

check = [
Expand Down Expand Up @@ -98,7 +112,7 @@ lint = [
{ cmd = "ruff format --check ." },
{ cmd = "ruff check ." },
{ cmd = "validate-pyproject pyproject.toml" },
# { cmd = "mypy" },
{ cmd = "mypy" },
]

release = [
Expand Down
11 changes: 8 additions & 3 deletions responder/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,16 @@ async def _static_response(self, req, resp):
with open(index, "r") as f:
resp.html = f.read()
else:
resp.status_code = status_codes.HTTP_404
resp.status_code = status_codes.HTTP_404 # type: ignore[attr-defined]
resp.text = "Not found."

def redirect(
self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
self,
resp,
location,
*,
set_text=True,
status_code=status_codes.HTTP_301, # type: ignore[attr-defined]
):
"""
Redirects a given response to a given location.
Expand Down Expand Up @@ -365,7 +370,7 @@ def serve(self, *, address=None, port=None, debug=False, **options):
port = 5042

def spawn():
uvicorn.run(self, host=address, port=port, debug=debug, **options)
uvicorn.run(self, host=address, port=port, **options)

spawn()

Expand Down
2 changes: 1 addition & 1 deletion responder/ext/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def docs_response(self, req, resp):
resp.html = self.docs

def schema_response(self, req, resp):
resp.status_code = status_codes.HTTP_200
resp.status_code = status_codes.HTTP_200 # type: ignore[attr-defined]
resp.headers["Content-Type"] = "application/x-yaml"
resp.content = self.openapi
41 changes: 31 additions & 10 deletions responder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

from .statics import DEFAULT_ENCODING
from .status_codes import HTTP_301
from .status_codes import HTTP_301 # type: ignore[attr-defined]


class QueryDict(dict):
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, scope, receive, api=None, formats=None):
self.api = api
self._content = None

headers = CaseInsensitiveDict()
headers: CaseInsensitiveDict = CaseInsensitiveDict()
for key, value in self._starlette.headers.items():
headers[key] = value

Expand Down Expand Up @@ -150,7 +150,7 @@ def cookies(self):
cookies = RequestsCookieJar()
cookie_header = self.headers.get("Cookie", "")

bc = SimpleCookie(cookie_header)
bc: SimpleCookie = SimpleCookie(cookie_header)
for key, morsel in bc.items():
cookies[key] = morsel.value

Expand Down Expand Up @@ -239,9 +239,20 @@ async def media(self, format: t.Union[str, t.Callable] = None): # noqa: A001, A
format = "yaml" if "yaml" in self.mimetype or "" else "json" # noqa: A001
format = "form" if "form" in self.mimetype or "" else format # noqa: A001

if format in self.formats:
return await self.formats[format](self)
return await format(self)
formatter: t.Callable
if isinstance(format, str):
try:
formatter = self.formats[format]
except KeyError as ex:
raise ValueError(f"Unable to process data in '{format}' format") from ex

elif callable(format):
formatter = format

else:
raise TypeError(f"Invalid 'format' argument: {format}")

return await formatter(self)


def content_setter(mimetype):
Expand Down Expand Up @@ -275,7 +286,8 @@ class Response:

def __init__(self, req, *, formats):
self.req = req
self.status_code = None #: The HTTP Status Code to use for the Response.
#: The HTTP Status Code to use for the Response.
self.status_code: t.Union[int, None] = None
self.content = None #: A bytes representation of the response body.
self.mimetype = None
self.encoding = DEFAULT_ENCODING
Expand All @@ -285,7 +297,7 @@ def __init__(self, req, *, formats):
self.headers = {} #: A Python dictionary of ``{key: value}``,
#: representing the headers of the response.
self.formats = formats
self.cookies = SimpleCookie() #: The cookies set in the Response
self.cookies: SimpleCookie = SimpleCookie() #: The cookies set in the Response
self.session = (
req.session
) #: The cookie-based session data, in dict form, to add to the Response.
Expand Down Expand Up @@ -365,16 +377,25 @@ async def __call__(self, scope, receive, send):
if self.headers:
headers.update(self.headers)

response_cls: t.Union[
t.Type[StarletteResponse], t.Type[StarletteStreamingResponse]
]
if self._stream is not None:
response_cls = StarletteStreamingResponse
else:
response_cls = StarletteResponse

response = response_cls(body, status_code=self.status_code, headers=headers)
response = response_cls(body, status_code=self.status_code_safe, headers=headers)
self._prepare_cookies(response)

await response(scope, receive, send)

@property
def ok(self):
return 200 <= self.status_code < 300
return 200 <= self.status_code_safe < 300

@property
def status_code_safe(self) -> int:
if self.status_code is None:
raise RuntimeError("HTTP status code has not been defined")
return self.status_code
14 changes: 8 additions & 6 deletions responder/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import inspect
import re
import traceback
import typing as t
from collections import defaultdict

from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.types import ASGIApp
from starlette.websockets import WebSocket, WebSocketClose

from . import status_codes
Expand Down Expand Up @@ -121,7 +123,7 @@ async def __call__(self, scope, receive, send):
views.append(view)
except AttributeError as ex:
if on_request is None:
raise HTTPException(status_code=status_codes.HTTP_405) from ex
raise HTTPException(status_code=status_codes.HTTP_405) from ex # type: ignore[attr-defined]
else:
views.append(self.endpoint)

Expand All @@ -135,7 +137,7 @@ async def __call__(self, scope, receive, send):
await run_in_threadpool(view, request, response, **path_params)

if response.status_code is None:
response.status_code = status_codes.HTTP_200
response.status_code = status_codes.HTTP_200 # type: ignore[attr-defined]

await response(scope, receive, send)

Expand Down Expand Up @@ -207,7 +209,7 @@ class Router:
def __init__(self, routes=None, default_response=None, before_requests=None):
self.routes = [] if routes is None else list(routes)
# [TODO] Make its own router
self.apps = {}
self.apps: t.Dict[str, ASGIApp] = {}
self.default_endpoint = (
self.default_response if default_response is None else default_response
)
Expand Down Expand Up @@ -255,7 +257,7 @@ def add_route(

def mount(self, route, app):
"""Mounts ASGI / WSGI applications at a given route"""
self.apps.update(route, app)
self.apps.update({route: app})

def add_event_handler(self, event_type, handler):
assert event_type in (
Expand Down Expand Up @@ -287,14 +289,14 @@ def url_for(self, endpoint, **params):
async def default_response(self, scope, receive, send):
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(receive, send)
await websocket_close(scope, receive, send)
return

# FIXME: Please review!
request = Request(scope, receive)
response = Response(request, formats=get_formats()) # noqa: F841

raise HTTPException(status_code=status_codes.HTTP_404)
raise HTTPException(status_code=status_codes.HTTP_404) # type: ignore[attr-defined]

def _resolve_route(self, scope):
for route in self.routes:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def run(self):
],
"graphql": ["graphene"],
"release": ["build", "twine"],
"test": ["pytest", "pytest-cov", "pytest-mock", "flask"],
"test": ["flask", "mypy", "pytest", "pytest-cov", "pytest-mock"],
},
include_package_data=True,
license="Apache 2.0",
Expand Down

0 comments on commit a5b6d36

Please sign in to comment.