Skip to content

Commit

Permalink
Mypy as pre-commit check + api_jws typing (#787)
Browse files Browse the repository at this point in the history
* feat(mypy): from tox to pre-commit

* fix(mypy): apply mypy fixes

* feat(api_jws): typing

Co-authored-by: JulianMaurin <[email protected]>
  • Loading branch information
JulianMaurin and JulianMaurin authored Aug 3, 2022
1 parent e8780ab commit f827be3
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 53 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ repos:
hooks:
- id: check-manifest
args: [--no-build-isolation]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.971"
hooks:
- id: mypy
61 changes: 31 additions & 30 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import binascii
import json
import warnings
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Type
from typing import Any, Type

from .algorithms import (
Algorithm,
Expand All @@ -23,7 +24,7 @@
class PyJWS:
header_typ = "JWT"

def __init__(self, algorithms=None, options=None):
def __init__(self, algorithms=None, options=None) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
Expand All @@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
self.options = {**self._get_default_options(), **options}

@staticmethod
def _get_default_options():
def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True}

def register_algorithm(self, alg_id, alg_obj):
def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
Expand All @@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id)

def unregister_algorithm(self, alg_id):
def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
Expand All @@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
del self._algorithms[alg_id]
self._valid_algs.remove(alg_id)

def get_algorithms(self):
def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
Expand All @@ -96,9 +97,9 @@ def encode(
self,
payload: bytes,
key: str,
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
) -> str:
segments = []
Expand All @@ -117,7 +118,7 @@ def encode(
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

if headers:
self._validate_headers(headers)
Expand Down Expand Up @@ -165,11 +166,11 @@ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
Expand Down Expand Up @@ -210,9 +211,9 @@ def decode(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> str:
if kwargs:
Expand All @@ -227,7 +228,7 @@ def decode(
)
return decoded["payload"]

def get_unverified_header(self, jwt):
def get_unverified_header(self, jwt: str | bytes) -> dict:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
Expand All @@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):

return headers

def _load(self, jwt):
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")

Expand All @@ -261,7 +262,7 @@ def _load(self, jwt):
except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e

if not isinstance(header, Mapping):
if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object")

try:
Expand All @@ -278,16 +279,16 @@ def _load(self, jwt):

def _verify_signature(
self,
signing_input,
header,
signature,
key="",
algorithms=None,
):
signing_input: bytes,
header: dict,
signature: bytes,
key: str = "",
algorithms: list[str] | None = None,
) -> None:

alg = header.get("alg")

if algorithms is not None and alg not in algorithms:
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
Expand All @@ -299,11 +300,11 @@ def _verify_signature(
if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")

def _validate_headers(self, headers):
def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])

def _validate_kid(self, kid):
def _validate_kid(self, kid: str) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")

Expand Down
13 changes: 7 additions & 6 deletions jwt/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
try:
import cryptography
except ModuleNotFoundError:
cryptography = None # type: ignore
cryptography = None


def info() -> Dict[str, Dict[str, str]]:
Expand All @@ -29,14 +29,15 @@ def info() -> Dict[str, Dict[str, str]]:
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == "PyPy":
pypy_version_info = getattr(sys, "pypy_version_info")
implementation_version = (
f"{sys.pypy_version_info.major}." # type: ignore[attr-defined]
f"{sys.pypy_version_info.minor}."
f"{sys.pypy_version_info.micro}"
f"{pypy_version_info.major}."
f"{pypy_version_info.minor}."
f"{pypy_version_info.micro}"
)
if sys.pypy_version_info.releaselevel != "final": # type: ignore[attr-defined]
if pypy_version_info.releaselevel != "final":
implementation_version = "".join(
[implementation_version, sys.pypy_version_info.releaselevel] # type: ignore[attr-defined]
[implementation_version, pypy_version_info.releaselevel]
)
else:
implementation_version = "Unknown"
Expand Down
4 changes: 2 additions & 2 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import binascii
import re
from typing import Any, Union
from typing import Union

try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
Expand All @@ -10,7 +10,7 @@
encode_dss_signature,
)
except ModuleNotFoundError:
EllipticCurve = Any # type: ignore
EllipticCurve = None


def force_bytes(value: Union[str, bytes]) -> bytes:
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[tool.coverage.run]
parallel = true
branch = true
Expand All @@ -14,8 +13,13 @@ source = ["jwt", ".tox/*/site-packages"]
[tool.coverage.report]
show_missing = true


[tool.isort]
profile = "black"
atomic = true
combine_as_imports = true

[tool.mypy]
python_version = 3.7
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
7 changes: 0 additions & 7 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ dev =
types-cryptography>=3.3.21
pytest>=6.0.0,<7.0.0
coverage[toml]==5.0.4
mypy
pre-commit

[options.packages.find]
Expand All @@ -67,9 +66,3 @@ exclude =

[flake8]
extend-ignore = E203, E501

[mypy]
python_version = 3.7
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
6 changes: 0 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ commands =
python -m doctest README.rst


[testenv:typing]
basepython = python3.8
extras = dev
commands = mypy jwt


[testenv:lint]
basepython = python3.8
extras = dev
Expand Down

0 comments on commit f827be3

Please sign in to comment.