diff --git a/timed/authentication.py b/timed/authentication.py index 491299ef..4a582847 100644 --- a/timed/authentication.py +++ b/timed/authentication.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import base64 import functools import hashlib +from typing import TYPE_CHECKING import requests from django.conf import settings @@ -10,9 +13,18 @@ from mozilla_django_oidc.auth import LOGGER, OIDCAuthenticationBackend from rest_framework.exceptions import AuthenticationFailed +if TYPE_CHECKING: + from typing import Callable, Self + + from django.db.models import QuerySet + + from timed.employment.models import User + class TimedOIDCAuthenticationBackend(OIDCAuthenticationBackend): - def get_introspection(self, access_token, _id_token, _payload): + def get_introspection( + self, access_token: str, _id_token: str, _payload: dict + ) -> dict: """Return user details dictionary.""" basic = base64.b64encode( f"{settings.OIDC_RP_INTROSPECT_CLIENT_ID}:{settings.OIDC_RP_INTROSPECT_CLIENT_SECRET}".encode() @@ -31,7 +43,7 @@ def get_introspection(self, access_token, _id_token, _payload): response.raise_for_status() return response.json() - def get_userinfo_or_introspection(self, access_token): + def get_userinfo_or_introspection(self, access_token: str) -> dict: try: return self.cached_request(self.get_userinfo, access_token, "auth.userinfo") except requests.HTTPError as exc: @@ -57,7 +69,9 @@ def get_userinfo_or_introspection(self, access_token): return claims raise AuthenticationFailed from exc - def get_or_create_user(self, access_token, _id_token, _payload): + def get_or_create_user( + self, access_token: str, _id_token: str, _payload: dict + ) -> User | None: """Verify claims and return user, otherwise raise an Exception.""" claims = self.get_userinfo_or_introspection(access_token) @@ -76,17 +90,22 @@ def get_or_create_user(self, access_token, _id_token, _payload): ) return None - def update_user_from_claims(self, user, claims): + def update_user_from_claims(self, user: User, claims: dict[str, str]) -> None: user.email = claims.get(settings.OIDC_EMAIL_CLAIM, "") user.first_name = claims.get(settings.OIDC_FIRSTNAME_CLAIM, "") user.last_name = claims.get(settings.OIDC_LASTNAME_CLAIM, "") user.save() - def filter_users_by_claims(self, claims): + def filter_users_by_claims(self, claims: dict[str, str]) -> QuerySet[User]: username = self.get_username(claims) return self.UserModel.objects.filter(username__iexact=username) - def cached_request(self, method, token, cache_prefix): + def cached_request( + self, + method: Callable[[Self, str, None, None], dict], + token: str, + cache_prefix: str, + ) -> dict: token_hash = hashlib.sha256(force_bytes(token)).hexdigest() func = functools.partial(method, token, None, None) @@ -97,7 +116,7 @@ def cached_request(self, method, token, cache_prefix): timeout=settings.OIDC_BEARER_TOKEN_REVALIDATION_TIME, ) - def create_user(self, claims): + def create_user(self, claims: dict[str, str]) -> User: """Return object for a newly created user account.""" username = self.get_username(claims) email = claims.get(settings.OIDC_EMAIL_CLAIM, "") @@ -108,7 +127,7 @@ def create_user(self, claims): username=username, email=email, first_name=first_name, last_name=last_name ) - def get_username(self, claims): + def get_username(self, claims: dict[str, str]) -> str: try: return claims[settings.OIDC_USERNAME_CLAIM] except KeyError as exc: