Skip to content

Commit

Permalink
Simplify using custom token classes in serializers (#517)
Browse files Browse the repository at this point in the history
For most cases this could be done by overriding get_token, which is simple
enough. The exception was TokenRefreshSerializer.validate where the entire
method needed to be copy-pasted to allow using a custom replacement for
RefreshToken. The other cases are changed the same way mainly for
consistency.
  • Loading branch information
vainu-arto authored Jan 29, 2022
1 parent 92124cf commit 1b2e20e
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, *args, **kwargs):

class TokenObtainSerializer(serializers.Serializer):
username_field = get_user_model().USERNAME_FIELD
token_class = None

default_error_messages = {
"no_active_account": _("No active account found with the given credentials")
Expand Down Expand Up @@ -57,15 +58,11 @@ def validate(self, attrs):

@classmethod
def get_token(cls, user):
raise NotImplementedError(
"Must implement `get_token` method for `TokenObtainSerializer` subclasses"
)
return cls.token_class.for_user(user)


class TokenObtainPairSerializer(TokenObtainSerializer):
@classmethod
def get_token(cls, user):
return RefreshToken.for_user(user)
token_class = RefreshToken

def validate(self, attrs):
data = super().validate(attrs)
Expand All @@ -82,9 +79,7 @@ def validate(self, attrs):


class TokenObtainSlidingSerializer(TokenObtainSerializer):
@classmethod
def get_token(cls, user):
return SlidingToken.for_user(user)
token_class = SlidingToken

def validate(self, attrs):
data = super().validate(attrs)
Expand All @@ -102,9 +97,10 @@ def validate(self, attrs):
class TokenRefreshSerializer(serializers.Serializer):
refresh = serializers.CharField()
access = serializers.CharField(read_only=True)
token_class = RefreshToken

def validate(self, attrs):
refresh = RefreshToken(attrs["refresh"])
refresh = self.token_class(attrs["refresh"])

data = {"access": str(refresh.access_token)}

Expand All @@ -129,9 +125,10 @@ def validate(self, attrs):

class TokenRefreshSlidingSerializer(serializers.Serializer):
token = serializers.CharField()
token_class = SlidingToken

def validate(self, attrs):
token = SlidingToken(attrs["token"])
token = self.token_class(attrs["token"])

# Check that the timestamp in the "refresh_exp" claim has not
# passed
Expand Down Expand Up @@ -163,9 +160,10 @@ def validate(self, attrs):

class TokenBlacklistSerializer(serializers.Serializer):
refresh = serializers.CharField()
token_class = RefreshToken

def validate(self, attrs):
refresh = RefreshToken(attrs["refresh"])
refresh = self.token_class(attrs["refresh"])
try:
refresh.blacklist()
except AttributeError:
Expand Down

0 comments on commit 1b2e20e

Please sign in to comment.