Skip to content

Commit

Permalink
feat(api_jws): typing
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianMaurin committed Aug 1, 2022
1 parent c6b6214 commit d71ba51
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import binascii
import json
import warnings
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Type
from typing import Any, Type

from .algorithms import (
Algorithm,
Expand All @@ -23,7 +24,7 @@
class PyJWS:
header_typ = "JWT"

def __init__(self, algorithms=None, options=None):
def __init__(self, algorithms=None, options=None) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
Expand All @@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
self.options = {**self._get_default_options(), **options}

@staticmethod
def _get_default_options():
def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True}

def register_algorithm(self, alg_id, alg_obj):
def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
Expand All @@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id)

def unregister_algorithm(self, alg_id):
def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
Expand All @@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
del self._algorithms[alg_id]
self._valid_algs.remove(alg_id)

def get_algorithms(self):
def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
Expand All @@ -96,9 +97,9 @@ def encode(
self,
payload: bytes,
key: str,
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
) -> str:
segments = []
Expand All @@ -117,7 +118,7 @@ def encode(
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

if headers:
self._validate_headers(headers)
Expand Down Expand Up @@ -165,11 +166,11 @@ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
Expand Down Expand Up @@ -210,9 +211,9 @@ def decode(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> str:
if kwargs:
Expand All @@ -227,7 +228,7 @@ def decode(
)
return decoded["payload"]

def get_unverified_header(self, jwt):
def get_unverified_header(self, jwt: str | bytes) -> dict:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
Expand All @@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):

return headers

def _load(self, jwt):
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")

Expand All @@ -261,7 +262,7 @@ def _load(self, jwt):
except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e

if not isinstance(header, Mapping):
if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object")

try:
Expand All @@ -278,16 +279,16 @@ def _load(self, jwt):

def _verify_signature(
self,
signing_input,
header,
signature,
key="",
algorithms=None,
):
signing_input: bytes,
header: dict,
signature: bytes,
key: str = "",
algorithms: list[str] | None = None,
) -> None:

alg = header.get("alg")

if algorithms is not None and alg not in algorithms:
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
Expand All @@ -299,11 +300,11 @@ def _verify_signature(
if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")

def _validate_headers(self, headers):
def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])

def _validate_kid(self, kid):
def _validate_kid(self, kid: str) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")

Expand Down

0 comments on commit d71ba51

Please sign in to comment.