From 275bf293ba1d98f4a90e38ef06e0e4ccd649c97a Mon Sep 17 00:00:00 2001 From: Oscar Chen Date: Thu, 11 Jan 2024 22:55:44 -0700 Subject: [PATCH] Implement feature to allow customizing token claim from user attributes --- ninja_simple_jwt/auth/ninja_auth.py | 11 +++++- ninja_simple_jwt/auth/views/api.py | 9 +++-- ninja_simple_jwt/jwt/json_encode.py | 12 ++++++ ninja_simple_jwt/jwt/token_operations.py | 47 ++++++++++++++---------- ninja_simple_jwt/settings.py | 8 ++++ setup.cfg | 2 +- tests/settings.py | 6 +++ tests/test_jwt/test_json_encode.py | 23 ++++++++++++ tests/test_jwt/test_token_operations.py | 2 +- 9 files changed, 94 insertions(+), 26 deletions(-) create mode 100644 ninja_simple_jwt/jwt/json_encode.py create mode 100644 tests/test_jwt/test_json_encode.py diff --git a/ninja_simple_jwt/auth/ninja_auth.py b/ninja_simple_jwt/auth/ninja_auth.py index 285e396..6b9d634 100644 --- a/ninja_simple_jwt/auth/ninja_auth.py +++ b/ninja_simple_jwt/auth/ninja_auth.py @@ -1,3 +1,5 @@ +from django.contrib.auth.base_user import AbstractBaseUser +from django.contrib.auth.models import AnonymousUser from django.http import HttpRequest from jwt import PyJWTError from ninja.errors import AuthenticationError @@ -5,6 +7,7 @@ from ninja.security.http import DecodeError from ninja_simple_jwt.jwt.token_operations import TokenTypes, decode_token +from ninja_simple_jwt.settings import ninja_simple_jwt_settings class HttpJwtAuth(HttpBearer): @@ -16,11 +19,15 @@ def authenticate(self, request: HttpRequest, token: str) -> bool: except PyJWTError as e: raise AuthenticationError(e) - setattr(request.user, "id", access_token["user_id"]) - setattr(request.user, "username", access_token["username"]) + self.set_token_claims_to_user(request.user, access_token) return True + @staticmethod + def set_token_claims_to_user(user: AbstractBaseUser | AnonymousUser, token: dict) -> None: + for claim, user_attribute in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP.items(): + setattr(user, user_attribute, token.get(claim)) + def decode_authorization(self, value: str) -> str: parts = value.split(" ") if len(parts) != 2 or parts[0].lower() != "bearer": diff --git a/ninja_simple_jwt/auth/views/api.py b/ninja_simple_jwt/auth/views/api.py index 7dd4ceb..fd89e35 100644 --- a/ninja_simple_jwt/auth/views/api.py +++ b/ninja_simple_jwt/auth/views/api.py @@ -1,6 +1,7 @@ from datetime import UTC, datetime from django.contrib.auth import authenticate +from django.contrib.auth.signals import user_logged_in from django.http import HttpRequest, HttpResponse from jwt.exceptions import PyJWTError from ninja import Router @@ -14,7 +15,7 @@ WebSignInResponse, ) from ninja_simple_jwt.jwt.token_operations import ( - get_access_token, + get_access_token_for_user, get_access_token_from_refresh_token, get_refresh_token_for_user, ) @@ -28,10 +29,11 @@ def mobile_sign_in(request: HttpRequest, payload: SignInRequest) -> dict: payload_data = payload.dict() user = authenticate(username=payload_data["username"], password=payload_data["password"]) + user_logged_in.send(sender=user.__class__, request=request, user=user) if user is None: raise AuthenticationError() refresh_token, _ = get_refresh_token_for_user(user) - access_token, _ = get_access_token(str(user.pk), user.get_username()) + access_token, _ = get_access_token_for_user(user) return {"refresh": refresh_token, "access": access_token} @@ -50,10 +52,11 @@ def mobile_token_refresh(request: HttpRequest, payload: MobileTokenRefreshReques def web_sign_in(request: HttpRequest, payload: SignInRequest, response: HttpResponse) -> dict: payload_data = payload.dict() user = authenticate(username=payload_data["username"], password=payload_data["password"]) + user_logged_in.send(sender=user.__class__, request=request, user=user) if user is None: raise AuthenticationError() refresh_token, refresh_token_payload = get_refresh_token_for_user(user) - access_token, _ = get_access_token(str(user.pk), user.get_username()) + access_token, _ = get_access_token_for_user(user) response.set_cookie( key=ninja_simple_jwt_settings.JWT_REFRESH_COOKIE_NAME, value=refresh_token, diff --git a/ninja_simple_jwt/jwt/json_encode.py b/ninja_simple_jwt/jwt/json_encode.py new file mode 100644 index 0000000..e7751b3 --- /dev/null +++ b/ninja_simple_jwt/jwt/json_encode.py @@ -0,0 +1,12 @@ +from typing import Any +from uuid import UUID + +from django.core.serializers.json import DjangoJSONEncoder + + +class TokenUserEncoder(DjangoJSONEncoder): + def default(self, o: Any) -> Any: + if isinstance(o, UUID): + return str(o) + + return super().default(o) diff --git a/ninja_simple_jwt/jwt/token_operations.py b/ninja_simple_jwt/jwt/token_operations.py index e6f4f4a..8842111 100644 --- a/ninja_simple_jwt/jwt/token_operations.py +++ b/ninja_simple_jwt/jwt/token_operations.py @@ -1,14 +1,15 @@ import time from datetime import datetime from enum import Enum -from typing import Any, Tuple, TypedDict +from json import JSONEncoder +from typing import Any, Optional, Tuple from uuid import uuid4 import jwt from django.contrib.auth.models import AbstractBaseUser from django.utils import timezone +from django.utils.module_loading import import_string from jwt import ExpiredSignatureError, InvalidKeyError, InvalidTokenError -from pydantic import UUID4 from ninja_simple_jwt.jwt.key_retrieval import InMemoryJwtKeyPair from ninja_simple_jwt.settings import ninja_simple_jwt_settings @@ -19,32 +20,35 @@ class TokenTypes(str, Enum): REFRESH = "refresh" -class TokenPayload(TypedDict): - user_id: str - username: str +TokenUserJsonEncoder = import_string(ninja_simple_jwt_settings.TOKEN_USER_ENCODER_CLS) -class DecodedTokenPayload(TokenPayload): - token_type: TokenTypes - jti: UUID4 - exp: int - iat: int +def get_refresh_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]: + payload = get_token_payload_for_user(user) + return encode_token(payload, TokenTypes.REFRESH, json_encoder=TokenUserJsonEncoder) -def get_refresh_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]: - return encode_token({"username": user.get_username(), "user_id": user.pk}, TokenTypes.REFRESH) +def get_access_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]: + payload = get_token_payload_for_user(user) + return encode_token(payload, TokenTypes.ACCESS, json_encoder=TokenUserJsonEncoder) -def get_access_token(user_id: str, user_name: str) -> Tuple[str, dict]: - return encode_token({"username": user_name, "user_id": user_id}, TokenTypes.ACCESS) +def get_token_payload_for_user(user: AbstractBaseUser) -> dict: + return { + claim: getattr(user, user_attr) + for claim, user_attr in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP.items() + } def get_access_token_from_refresh_token(refresh_token: str) -> Tuple[str, dict]: decoded = decode_token(refresh_token, token_type=TokenTypes.REFRESH, verify=True) - return get_access_token(user_id=decoded["user_id"], user_name=decoded["username"]) + payload = {claim: decoded.get(claim) for claim in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP} + return encode_token(payload, TokenTypes.ACCESS) -def encode_token(payload: TokenPayload, token_type: TokenTypes, **additional_headers: Any) -> Tuple[str, dict]: +def encode_token( + payload: dict, token_type: TokenTypes, json_encoder: Optional[type[JSONEncoder]] = None, **additional_headers: Any +) -> Tuple[str, dict]: now = timezone.now() if token_type == TokenTypes.REFRESH: expiry = now + ninja_simple_jwt_settings.JWT_REFRESH_TOKEN_LIFETIME @@ -53,7 +57,6 @@ def encode_token(payload: TokenPayload, token_type: TokenTypes, **additional_hea payload_data = { **payload, - "user_id": str(payload["user_id"]), "jti": uuid4().hex, "exp": int(expiry.timestamp()), "iat": int(now.timestamp()), @@ -61,12 +64,18 @@ def encode_token(payload: TokenPayload, token_type: TokenTypes, **additional_hea } return ( - jwt.encode(payload_data, InMemoryJwtKeyPair.private_key, algorithm="RS256", headers=additional_headers), + jwt.encode( + payload_data, + InMemoryJwtKeyPair.private_key, + algorithm="RS256", + headers=additional_headers, + json_encoder=json_encoder, + ), payload_data, ) -def decode_token(token: str, token_type: TokenTypes, verify: bool = True) -> DecodedTokenPayload: +def decode_token(token: str, token_type: TokenTypes, verify: bool = True) -> dict: if verify is True: decoded = jwt.decode(token, InMemoryJwtKeyPair.public_key, algorithms=["RS256"]) _verify_exp(decoded) diff --git a/ninja_simple_jwt/settings.py b/ninja_simple_jwt/settings.py index a9336e1..57dbe65 100644 --- a/ninja_simple_jwt/settings.py +++ b/ninja_simple_jwt/settings.py @@ -19,6 +19,8 @@ class NinjaSimpleJwtSettingsDict(TypedDict): WEB_REFRESH_COOKIE_HTTP_ONLY: NotRequired[bool] WEB_REFRESH_COOKIE_SAME_SITE_POLICY: NotRequired[str] WEB_REFRESH_COOKIE_PATH: NotRequired[str] + TOKEN_CLAIM_USER_ATTRIBUTE_MAP: NotRequired[dict[str, str]] + TOKEN_USER_ENCODER_CLS: NotRequired[str] DEFAULTS: NinjaSimpleJwtSettingsDict = { @@ -32,6 +34,12 @@ class NinjaSimpleJwtSettingsDict(TypedDict): "WEB_REFRESH_COOKIE_HTTP_ONLY": True, "WEB_REFRESH_COOKIE_SAME_SITE_POLICY": "Strict", "WEB_REFRESH_COOKIE_PATH": "/api/auth/web/token-refresh", + "TOKEN_CLAIM_USER_ATTRIBUTE_MAP": { + "user_id": "id", + "username": "username", + "last_login": "last_login", + }, + "TOKEN_USER_ENCODER_CLS": "ninja_simple_jwt.jwt.json_encode.TokenUserEncoder", } EMPTY_SETTINGS: NinjaSimpleJwtSettingsDict = {} diff --git a/setup.cfg b/setup.cfg index b0a45de..b69f9e6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = django-ninja-simple-jwt -version = 0.3.0 +version = 0.4.0 description = Simple JWT-based authentication using Django and Django-ninja long_description = file: README.md url = https://github.com/oscarychen/django-ninja-simple-jwt diff --git a/tests/settings.py b/tests/settings.py index 8338d7d..2122b5f 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -74,6 +74,7 @@ "django.contrib.staticfiles", "django.contrib.admin", "django.contrib.messages", + "ninja_simple_jwt", "tests", ) @@ -110,3 +111,8 @@ }, }, } + +NINJA_SIMPLE_JWT = { + "JWT_PRIVATE_KEY_PATH": "tests/mock-jwt-signing.pem", + "JWT_PUBLIC_KEY_PATH": "tests/mock-jwt-signing.pub", +} diff --git a/tests/test_jwt/test_json_encode.py b/tests/test_jwt/test_json_encode.py new file mode 100644 index 0000000..fc6355f --- /dev/null +++ b/tests/test_jwt/test_json_encode.py @@ -0,0 +1,23 @@ +import json +from datetime import datetime +from uuid import uuid4 + +from django.test import TestCase + +from ninja_simple_jwt.jwt.json_encode import TokenUserEncoder + + +class TestDjangoUserEncoder(TestCase): + def test_encoder_can_serialize_datetime(self) -> None: + test_data = datetime(2012, 1, 14, 12, 0, 1) + + result = json.dumps(test_data, cls=TokenUserEncoder) + + self.assertEqual('"2012-01-14T12:00:01"', result) + + def test_encoder_can_serialize_uuid(self) -> None: + test_uuid = uuid4() + + result = json.dumps(test_uuid, cls=TokenUserEncoder) + + self.assertEqual(f'"{str(test_uuid)}"', result) diff --git a/tests/test_jwt/test_token_operations.py b/tests/test_jwt/test_token_operations.py index 01d5510..c06b046 100644 --- a/tests/test_jwt/test_token_operations.py +++ b/tests/test_jwt/test_token_operations.py @@ -6,7 +6,7 @@ from ninja_simple_jwt.jwt.token_operations import TokenTypes, decode_token, encode_token -class TestEncodeToken(TestCase): +class TestEncodeDecodeToken(TestCase): def setUp(self) -> None: make_and_save_key_pair()