diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index 8e98ceda6..64213c722 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -24,6 +24,7 @@ def __init__(self, *args, **kwargs): class TokenObtainSerializer(serializers.Serializer): username_field = get_user_model().USERNAME_FIELD + token_class = None default_error_messages = { "no_active_account": _("No active account found with the given credentials") @@ -57,15 +58,11 @@ def validate(self, attrs): @classmethod def get_token(cls, user): - raise NotImplementedError( - "Must implement `get_token` method for `TokenObtainSerializer` subclasses" - ) + return cls.token_class.for_user(user) class TokenObtainPairSerializer(TokenObtainSerializer): - @classmethod - def get_token(cls, user): - return RefreshToken.for_user(user) + token_class = RefreshToken def validate(self, attrs): data = super().validate(attrs) @@ -82,9 +79,7 @@ def validate(self, attrs): class TokenObtainSlidingSerializer(TokenObtainSerializer): - @classmethod - def get_token(cls, user): - return SlidingToken.for_user(user) + token_class = SlidingToken def validate(self, attrs): data = super().validate(attrs) @@ -102,9 +97,10 @@ def validate(self, attrs): class TokenRefreshSerializer(serializers.Serializer): refresh = serializers.CharField() access = serializers.CharField(read_only=True) + token_class = RefreshToken def validate(self, attrs): - refresh = RefreshToken(attrs["refresh"]) + refresh = self.token_class(attrs["refresh"]) data = {"access": str(refresh.access_token)} @@ -129,9 +125,10 @@ def validate(self, attrs): class TokenRefreshSlidingSerializer(serializers.Serializer): token = serializers.CharField() + token_class = SlidingToken def validate(self, attrs): - token = SlidingToken(attrs["token"]) + token = self.token_class(attrs["token"]) # Check that the timestamp in the "refresh_exp" claim has not # passed @@ -163,9 +160,10 @@ def validate(self, attrs): class TokenBlacklistSerializer(serializers.Serializer): refresh = serializers.CharField() + token_class = RefreshToken def validate(self, attrs): - refresh = RefreshToken(attrs["refresh"]) + refresh = self.token_class(attrs["refresh"]) try: refresh.blacklist() except AttributeError: