From 1b2e20e48c05f0f2975c2fafbb12146cecc32882 Mon Sep 17 00:00:00 2001 From: vainu-arto <70135394+vainu-arto@users.noreply.github.com> Date: Sat, 29 Jan 2022 05:56:23 +0200 Subject: [PATCH] Simplify using custom token classes in serializers (#517) For most cases this could be done by overriding get_token, which is simple enough. The exception was TokenRefreshSerializer.validate where the entire method needed to be copy-pasted to allow using a custom replacement for RefreshToken. The other cases are changed the same way mainly for consistency. --- rest_framework_simplejwt/serializers.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index 8e98ceda6..64213c722 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -24,6 +24,7 @@ def __init__(self, *args, **kwargs): class TokenObtainSerializer(serializers.Serializer): username_field = get_user_model().USERNAME_FIELD + token_class = None default_error_messages = { "no_active_account": _("No active account found with the given credentials") @@ -57,15 +58,11 @@ def validate(self, attrs): @classmethod def get_token(cls, user): - raise NotImplementedError( - "Must implement `get_token` method for `TokenObtainSerializer` subclasses" - ) + return cls.token_class.for_user(user) class TokenObtainPairSerializer(TokenObtainSerializer): - @classmethod - def get_token(cls, user): - return RefreshToken.for_user(user) + token_class = RefreshToken def validate(self, attrs): data = super().validate(attrs) @@ -82,9 +79,7 @@ def validate(self, attrs): class TokenObtainSlidingSerializer(TokenObtainSerializer): - @classmethod - def get_token(cls, user): - return SlidingToken.for_user(user) + token_class = SlidingToken def validate(self, attrs): data = super().validate(attrs) @@ -102,9 +97,10 @@ def validate(self, attrs): class TokenRefreshSerializer(serializers.Serializer): refresh = serializers.CharField() access = serializers.CharField(read_only=True) + token_class = RefreshToken def validate(self, attrs): - refresh = RefreshToken(attrs["refresh"]) + refresh = self.token_class(attrs["refresh"]) data = {"access": str(refresh.access_token)} @@ -129,9 +125,10 @@ def validate(self, attrs): class TokenRefreshSlidingSerializer(serializers.Serializer): token = serializers.CharField() + token_class = SlidingToken def validate(self, attrs): - token = SlidingToken(attrs["token"]) + token = self.token_class(attrs["token"]) # Check that the timestamp in the "refresh_exp" claim has not # passed @@ -163,9 +160,10 @@ def validate(self, attrs): class TokenBlacklistSerializer(serializers.Serializer): refresh = serializers.CharField() + token_class = RefreshToken def validate(self, attrs): - refresh = RefreshToken(attrs["refresh"]) + refresh = self.token_class(attrs["refresh"]) try: refresh.blacklist() except AttributeError: