diff --git a/rest_auth/serializers.py b/rest_auth/serializers.py index b6452317..02d47ac0 100644 --- a/rest_auth/serializers.py +++ b/rest_auth/serializers.py @@ -137,7 +137,8 @@ class JWTSerializer(serializers.Serializer): """ Serializer for JWT authentication. """ - token = serializers.CharField() + access_token = serializers.CharField() + refresh_token = serializers.CharField() user = serializers.SerializerMethodField() def get_user(self, obj): diff --git a/rest_auth/utils.py b/rest_auth/utils.py index 800f184c..7fb89392 100644 --- a/rest_auth/utils.py +++ b/rest_auth/utils.py @@ -1,5 +1,7 @@ from six import string_types from importlib import import_module +from rest_framework_simplejwt.serializers import TokenObtainPairSerializer +from rest_framework_simplejwt.views import TokenObtainPairView def import_callable(path_or_callable): @@ -17,13 +19,5 @@ def default_create_token(token_model, user, serializer): def jwt_encode(user): - try: - from rest_framework_jwt.settings import api_settings - except ImportError: - raise ImportError("djangorestframework_jwt needs to be installed") - - jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER - jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER - - payload = jwt_payload_handler(user) - return jwt_encode_handler(payload) + refresh = TokenObtainPairSerializer.get_token(user) + return refresh.access_token, refresh diff --git a/rest_auth/views.py b/rest_auth/views.py index 0a0a982e..8f443dec 100644 --- a/rest_auth/views.py +++ b/rest_auth/views.py @@ -62,7 +62,7 @@ def login(self): self.user = self.serializer.validated_data['user'] if getattr(settings, 'REST_USE_JWT', False): - self.token = jwt_encode(self.user) + self.access_token, self.refresh_token = jwt_encode(self.user) else: self.token = create_token(self.token_model, self.user, self.serializer) @@ -76,7 +76,8 @@ def get_response(self): if getattr(settings, 'REST_USE_JWT', False): data = { 'user': self.user, - 'token': self.token + 'access_token': self.access_token, + 'refresh_token': self.refresh_token, } serializer = serializer_class(instance=data, context={'request': self.request}) @@ -85,15 +86,6 @@ def get_response(self): context={'request': self.request}) response = Response(serializer.data, status=status.HTTP_200_OK) - if getattr(settings, 'REST_USE_JWT', False): - from rest_framework_jwt.settings import api_settings as jwt_settings - if jwt_settings.JWT_AUTH_COOKIE: - from datetime import datetime - expiration = (datetime.utcnow() + jwt_settings.JWT_EXPIRATION_DELTA) - response.set_cookie(jwt_settings.JWT_AUTH_COOKIE, - self.token, - expires=expiration, - httponly=True) return response def post(self, request, *args, **kwargs):