Skip to content

Commit

Permalink
Implement feature to allow customizing token claim from user attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarychen committed Jan 13, 2024
1 parent e23792d commit 275bf29
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 26 deletions.
11 changes: 9 additions & 2 deletions ninja_simple_jwt/auth/ninja_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.http import HttpRequest
from jwt import PyJWTError
from ninja.errors import AuthenticationError
from ninja.security import HttpBearer
from ninja.security.http import DecodeError

from ninja_simple_jwt.jwt.token_operations import TokenTypes, decode_token
from ninja_simple_jwt.settings import ninja_simple_jwt_settings


class HttpJwtAuth(HttpBearer):
Expand All @@ -16,11 +19,15 @@ def authenticate(self, request: HttpRequest, token: str) -> bool:
except PyJWTError as e:
raise AuthenticationError(e)

setattr(request.user, "id", access_token["user_id"])
setattr(request.user, "username", access_token["username"])
self.set_token_claims_to_user(request.user, access_token)

return True

@staticmethod
def set_token_claims_to_user(user: AbstractBaseUser | AnonymousUser, token: dict) -> None:
for claim, user_attribute in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP.items():
setattr(user, user_attribute, token.get(claim))

def decode_authorization(self, value: str) -> str:
parts = value.split(" ")
if len(parts) != 2 or parts[0].lower() != "bearer":
Expand Down
9 changes: 6 additions & 3 deletions ninja_simple_jwt/auth/views/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import UTC, datetime

from django.contrib.auth import authenticate
from django.contrib.auth.signals import user_logged_in
from django.http import HttpRequest, HttpResponse
from jwt.exceptions import PyJWTError
from ninja import Router
Expand All @@ -14,7 +15,7 @@
WebSignInResponse,
)
from ninja_simple_jwt.jwt.token_operations import (
get_access_token,
get_access_token_for_user,
get_access_token_from_refresh_token,
get_refresh_token_for_user,
)
Expand All @@ -28,10 +29,11 @@
def mobile_sign_in(request: HttpRequest, payload: SignInRequest) -> dict:
payload_data = payload.dict()
user = authenticate(username=payload_data["username"], password=payload_data["password"])
user_logged_in.send(sender=user.__class__, request=request, user=user)
if user is None:
raise AuthenticationError()
refresh_token, _ = get_refresh_token_for_user(user)
access_token, _ = get_access_token(str(user.pk), user.get_username())
access_token, _ = get_access_token_for_user(user)
return {"refresh": refresh_token, "access": access_token}


Expand All @@ -50,10 +52,11 @@ def mobile_token_refresh(request: HttpRequest, payload: MobileTokenRefreshReques
def web_sign_in(request: HttpRequest, payload: SignInRequest, response: HttpResponse) -> dict:
payload_data = payload.dict()
user = authenticate(username=payload_data["username"], password=payload_data["password"])
user_logged_in.send(sender=user.__class__, request=request, user=user)
if user is None:
raise AuthenticationError()
refresh_token, refresh_token_payload = get_refresh_token_for_user(user)
access_token, _ = get_access_token(str(user.pk), user.get_username())
access_token, _ = get_access_token_for_user(user)
response.set_cookie(
key=ninja_simple_jwt_settings.JWT_REFRESH_COOKIE_NAME,
value=refresh_token,
Expand Down
12 changes: 12 additions & 0 deletions ninja_simple_jwt/jwt/json_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any
from uuid import UUID

from django.core.serializers.json import DjangoJSONEncoder


class TokenUserEncoder(DjangoJSONEncoder):
def default(self, o: Any) -> Any:
if isinstance(o, UUID):
return str(o)

return super().default(o)
47 changes: 28 additions & 19 deletions ninja_simple_jwt/jwt/token_operations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import time
from datetime import datetime
from enum import Enum
from typing import Any, Tuple, TypedDict
from json import JSONEncoder
from typing import Any, Optional, Tuple
from uuid import uuid4

import jwt
from django.contrib.auth.models import AbstractBaseUser
from django.utils import timezone
from django.utils.module_loading import import_string
from jwt import ExpiredSignatureError, InvalidKeyError, InvalidTokenError
from pydantic import UUID4

from ninja_simple_jwt.jwt.key_retrieval import InMemoryJwtKeyPair
from ninja_simple_jwt.settings import ninja_simple_jwt_settings
Expand All @@ -19,32 +20,35 @@ class TokenTypes(str, Enum):
REFRESH = "refresh"


class TokenPayload(TypedDict):
user_id: str
username: str
TokenUserJsonEncoder = import_string(ninja_simple_jwt_settings.TOKEN_USER_ENCODER_CLS)


class DecodedTokenPayload(TokenPayload):
token_type: TokenTypes
jti: UUID4
exp: int
iat: int
def get_refresh_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]:
payload = get_token_payload_for_user(user)
return encode_token(payload, TokenTypes.REFRESH, json_encoder=TokenUserJsonEncoder)


def get_refresh_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]:
return encode_token({"username": user.get_username(), "user_id": user.pk}, TokenTypes.REFRESH)
def get_access_token_for_user(user: AbstractBaseUser) -> Tuple[str, dict]:
payload = get_token_payload_for_user(user)
return encode_token(payload, TokenTypes.ACCESS, json_encoder=TokenUserJsonEncoder)


def get_access_token(user_id: str, user_name: str) -> Tuple[str, dict]:
return encode_token({"username": user_name, "user_id": user_id}, TokenTypes.ACCESS)
def get_token_payload_for_user(user: AbstractBaseUser) -> dict:
return {
claim: getattr(user, user_attr)
for claim, user_attr in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP.items()
}


def get_access_token_from_refresh_token(refresh_token: str) -> Tuple[str, dict]:
decoded = decode_token(refresh_token, token_type=TokenTypes.REFRESH, verify=True)
return get_access_token(user_id=decoded["user_id"], user_name=decoded["username"])
payload = {claim: decoded.get(claim) for claim in ninja_simple_jwt_settings.TOKEN_CLAIM_USER_ATTRIBUTE_MAP}
return encode_token(payload, TokenTypes.ACCESS)


def encode_token(payload: TokenPayload, token_type: TokenTypes, **additional_headers: Any) -> Tuple[str, dict]:
def encode_token(
payload: dict, token_type: TokenTypes, json_encoder: Optional[type[JSONEncoder]] = None, **additional_headers: Any
) -> Tuple[str, dict]:
now = timezone.now()
if token_type == TokenTypes.REFRESH:
expiry = now + ninja_simple_jwt_settings.JWT_REFRESH_TOKEN_LIFETIME
Expand All @@ -53,20 +57,25 @@ def encode_token(payload: TokenPayload, token_type: TokenTypes, **additional_hea

payload_data = {
**payload,
"user_id": str(payload["user_id"]),
"jti": uuid4().hex,
"exp": int(expiry.timestamp()),
"iat": int(now.timestamp()),
"token_type": token_type,
}

return (
jwt.encode(payload_data, InMemoryJwtKeyPair.private_key, algorithm="RS256", headers=additional_headers),
jwt.encode(
payload_data,
InMemoryJwtKeyPair.private_key,
algorithm="RS256",
headers=additional_headers,
json_encoder=json_encoder,
),
payload_data,
)


def decode_token(token: str, token_type: TokenTypes, verify: bool = True) -> DecodedTokenPayload:
def decode_token(token: str, token_type: TokenTypes, verify: bool = True) -> dict:
if verify is True:
decoded = jwt.decode(token, InMemoryJwtKeyPair.public_key, algorithms=["RS256"])
_verify_exp(decoded)
Expand Down
8 changes: 8 additions & 0 deletions ninja_simple_jwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class NinjaSimpleJwtSettingsDict(TypedDict):
WEB_REFRESH_COOKIE_HTTP_ONLY: NotRequired[bool]
WEB_REFRESH_COOKIE_SAME_SITE_POLICY: NotRequired[str]
WEB_REFRESH_COOKIE_PATH: NotRequired[str]
TOKEN_CLAIM_USER_ATTRIBUTE_MAP: NotRequired[dict[str, str]]
TOKEN_USER_ENCODER_CLS: NotRequired[str]


DEFAULTS: NinjaSimpleJwtSettingsDict = {
Expand All @@ -32,6 +34,12 @@ class NinjaSimpleJwtSettingsDict(TypedDict):
"WEB_REFRESH_COOKIE_HTTP_ONLY": True,
"WEB_REFRESH_COOKIE_SAME_SITE_POLICY": "Strict",
"WEB_REFRESH_COOKIE_PATH": "/api/auth/web/token-refresh",
"TOKEN_CLAIM_USER_ATTRIBUTE_MAP": {
"user_id": "id",
"username": "username",
"last_login": "last_login",
},
"TOKEN_USER_ENCODER_CLS": "ninja_simple_jwt.jwt.json_encode.TokenUserEncoder",
}

EMPTY_SETTINGS: NinjaSimpleJwtSettingsDict = {}
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = django-ninja-simple-jwt
version = 0.3.0
version = 0.4.0
description = Simple JWT-based authentication using Django and Django-ninja
long_description = file: README.md
url = https://github.com/oscarychen/django-ninja-simple-jwt
Expand Down
6 changes: 6 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"django.contrib.staticfiles",
"django.contrib.admin",
"django.contrib.messages",
"ninja_simple_jwt",
"tests",
)

Expand Down Expand Up @@ -110,3 +111,8 @@
},
},
}

NINJA_SIMPLE_JWT = {
"JWT_PRIVATE_KEY_PATH": "tests/mock-jwt-signing.pem",
"JWT_PUBLIC_KEY_PATH": "tests/mock-jwt-signing.pub",
}
23 changes: 23 additions & 0 deletions tests/test_jwt/test_json_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json
from datetime import datetime
from uuid import uuid4

from django.test import TestCase

from ninja_simple_jwt.jwt.json_encode import TokenUserEncoder


class TestDjangoUserEncoder(TestCase):
def test_encoder_can_serialize_datetime(self) -> None:
test_data = datetime(2012, 1, 14, 12, 0, 1)

result = json.dumps(test_data, cls=TokenUserEncoder)

self.assertEqual('"2012-01-14T12:00:01"', result)

def test_encoder_can_serialize_uuid(self) -> None:
test_uuid = uuid4()

result = json.dumps(test_uuid, cls=TokenUserEncoder)

self.assertEqual(f'"{str(test_uuid)}"', result)
2 changes: 1 addition & 1 deletion tests/test_jwt/test_token_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ninja_simple_jwt.jwt.token_operations import TokenTypes, decode_token, encode_token


class TestEncodeToken(TestCase):
class TestEncodeDecodeToken(TestCase):
def setUp(self) -> None:
make_and_save_key_pair()

Expand Down

0 comments on commit 275bf29

Please sign in to comment.