From d71ba5115df672869557cbd5e80935d78140a471 Mon Sep 17 00:00:00 2001 From: JulianMaurin Date: Mon, 1 Aug 2022 19:16:00 +0200 Subject: [PATCH] feat(api_jws): typing --- jwt/api_jws.py | 61 +++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 90206c9a2..ab8490f9f 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import binascii import json import warnings -from collections.abc import Mapping -from typing import Any, Dict, List, Optional, Type +from typing import Any, Type from .algorithms import ( Algorithm, @@ -23,7 +24,7 @@ class PyJWS: header_typ = "JWT" - def __init__(self, algorithms=None, options=None): + def __init__(self, algorithms=None, options=None) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( set(algorithms) if algorithms is not None else set(self._algorithms) @@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None): self.options = {**self._get_default_options(), **options} @staticmethod - def _get_default_options(): + def _get_default_options() -> dict[str, bool]: return {"verify_signature": True} - def register_algorithm(self, alg_id, alg_obj): + def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None: """ Registers a new Algorithm for use when creating and verifying tokens. """ @@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj): self._algorithms[alg_id] = alg_obj self._valid_algs.add(alg_id) - def unregister_algorithm(self, alg_id): + def unregister_algorithm(self, alg_id: str) -> None: """ Unregisters an Algorithm for use when creating and verifying tokens Throws KeyError if algorithm is not registered. @@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id): del self._algorithms[alg_id] self._valid_algs.remove(alg_id) - def get_algorithms(self): + def get_algorithms(self) -> list[str]: """ Returns a list of supported values for the 'alg' parameter. """ @@ -96,9 +97,9 @@ def encode( self, payload: bytes, key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: Type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, ) -> str: segments = [] @@ -117,7 +118,7 @@ def encode( is_payload_detached = True # Header - header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any] + header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_} if headers: self._validate_headers(headers) @@ -165,11 +166,11 @@ def decode_complete( self, jwt: str, key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, - detached_payload: Optional[bytes] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, + detached_payload: bytes | None = None, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -210,9 +211,9 @@ def decode( self, jwt: str, key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, - detached_payload: Optional[bytes] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, + detached_payload: bytes | None = None, **kwargs, ) -> str: if kwargs: @@ -227,7 +228,7 @@ def decode( ) return decoded["payload"] - def get_unverified_header(self, jwt): + def get_unverified_header(self, jwt: str | bytes) -> dict: """Returns back the JWT header parameters as a dict() Note: The signature is not verified so the header parameters @@ -238,7 +239,7 @@ def get_unverified_header(self, jwt): return headers - def _load(self, jwt): + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: if isinstance(jwt, str): jwt = jwt.encode("utf-8") @@ -261,7 +262,7 @@ def _load(self, jwt): except ValueError as e: raise DecodeError(f"Invalid header string: {e}") from e - if not isinstance(header, Mapping): + if not isinstance(header, dict): raise DecodeError("Invalid header string: must be a json object") try: @@ -278,16 +279,16 @@ def _load(self, jwt): def _verify_signature( self, - signing_input, - header, - signature, - key="", - algorithms=None, - ): + signing_input: bytes, + header: dict, + signature: bytes, + key: str = "", + algorithms: list[str] | None = None, + ) -> None: alg = header.get("alg") - if algorithms is not None and alg not in algorithms: + if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") try: @@ -299,11 +300,11 @@ def _verify_signature( if not alg_obj.verify(signing_input, key, signature): raise InvalidSignatureError("Signature verification failed") - def _validate_headers(self, headers): + def _validate_headers(self, headers: dict[str, Any]) -> None: if "kid" in headers: self._validate_kid(headers["kid"]) - def _validate_kid(self, kid): + def _validate_kid(self, kid: str) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string")