From ed0f7b61ed4570f8e2a9f17f0a377be5dcf106b8 Mon Sep 17 00:00:00 2001 From: yeongkwang Date: Wed, 2 Feb 2022 15:20:05 +0900 Subject: [PATCH] Make the token serializer configurable (#521) --- rest_framework_simplejwt/settings.py | 6 ++++++ rest_framework_simplejwt/views.py | 22 ++++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index e6463b4f7..b4bac42ff 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -34,6 +34,12 @@ "SLIDING_TOKEN_REFRESH_EXP_CLAIM": "refresh_exp", "SLIDING_TOKEN_LIFETIME": timedelta(minutes=5), "SLIDING_TOKEN_REFRESH_LIFETIME": timedelta(days=1), + "TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainPairSerializer", + "TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSerializer", + "TOKEN_VERIFY_SERIALIZER": "rest_framework_simplejwt.serializers.TokenVerifySerializer", + "TOKEN_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenBlacklistSerializer", + "SLIDING_TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer", + "SLIDING_TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer", } IMPORT_STRINGS = ( diff --git a/rest_framework_simplejwt/views.py b/rest_framework_simplejwt/views.py index 7a13eb59d..f523469ac 100644 --- a/rest_framework_simplejwt/views.py +++ b/rest_framework_simplejwt/views.py @@ -1,9 +1,11 @@ +from django.utils.module_loading import import_string from rest_framework import generics, status from rest_framework.response import Response from . import serializers from .authentication import AUTH_HEADER_TYPES from .exceptions import InvalidToken, TokenError +from .settings import api_settings class TokenViewBase(generics.GenericAPIView): @@ -14,6 +16,14 @@ class TokenViewBase(generics.GenericAPIView): www_authenticate_realm = "api" + def get_serializer_class(self): + # Get the serializer from settings + try: + return import_string(self._serializer_class) + except ImportError: + msg = "Could not import serializer '%s'" % self._serializer_class + raise ImportError(msg) + def get_authenticate_header(self, request): return '{} realm="{}"'.format( AUTH_HEADER_TYPES[0], @@ -37,7 +47,7 @@ class TokenObtainPairView(TokenViewBase): token pair to prove the authentication of those credentials. """ - serializer_class = serializers.TokenObtainPairSerializer + _serializer_class = api_settings.TOKEN_OBTAIN_SERIALIZER token_obtain_pair = TokenObtainPairView.as_view() @@ -49,7 +59,7 @@ class TokenRefreshView(TokenViewBase): token if the refresh token is valid. """ - serializer_class = serializers.TokenRefreshSerializer + _serializer_class = api_settings.TOKEN_REFRESH_SERIALIZER token_refresh = TokenRefreshView.as_view() @@ -61,7 +71,7 @@ class TokenObtainSlidingView(TokenViewBase): prove the authentication of those credentials. """ - serializer_class = serializers.TokenObtainSlidingSerializer + _serializer_class = api_settings.SLIDING_TOKEN_OBTAIN_SERIALIZER token_obtain_sliding = TokenObtainSlidingView.as_view() @@ -73,7 +83,7 @@ class TokenRefreshSlidingView(TokenViewBase): token's refresh period has not expired. """ - serializer_class = serializers.TokenRefreshSlidingSerializer + _serializer_class = api_settings.SLIDING_TOKEN_REFRESH_SERIALIZER token_refresh_sliding = TokenRefreshSlidingView.as_view() @@ -85,7 +95,7 @@ class TokenVerifyView(TokenViewBase): information about a token's fitness for a particular use. """ - serializer_class = serializers.TokenVerifySerializer + _serializer_class = api_settings.TOKEN_VERIFY_SERIALIZER token_verify = TokenVerifyView.as_view() @@ -97,7 +107,7 @@ class TokenBlacklistView(TokenViewBase): `rest_framework_simplejwt.token_blacklist` app installed. """ - serializer_class = serializers.TokenBlacklistSerializer + _serializer_class = api_settings.TOKEN_BLACKLIST_SERIALIZER token_blacklist = TokenBlacklistView.as_view()