From 46072d7dedc72c0b2248047521c6c34fa03027a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20Dohnal?= Date: Mon, 25 Jul 2022 19:49:10 +0200 Subject: [PATCH] Add a utility function for creating user tokens (#341) * Add a utility function for creating user tokens * Fix async context * Fix mypy * Update tests and behavior --- strawberry_django_jwt/decorators.py | 7 ++---- strawberry_django_jwt/object_types.py | 8 ++----- .../refresh_token/shortcuts.py | 3 ++- strawberry_django_jwt/utils.py | 23 ++++++++++++++++++- tests/test_utils.py | 21 +++++++++++++++++ 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/strawberry_django_jwt/decorators.py b/strawberry_django_jwt/decorators.py index 7a0f92c7..a63f139d 100644 --- a/strawberry_django_jwt/decorators.py +++ b/strawberry_django_jwt/decorators.py @@ -204,11 +204,8 @@ def refresh_expiration(f): @wraps(f) def wrapper(cls, *args, **kwargs): def on_resolve(payload): - payload.refresh_expires_in = ( - timegm(datetime.utcnow().utctimetuple()) + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() - if jwt_settings.JWT_LONG_RUNNING_REFRESH_TOKEN - else None - ) + if jwt_settings.JWT_ALLOW_REFRESH: + payload.refresh_expires_in = timegm(datetime.utcnow().utctimetuple()) + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() return payload result = f(cls, *args, **kwargs) diff --git a/strawberry_django_jwt/object_types.py b/strawberry_django_jwt/object_types.py index c4246c52..162c27ac 100644 --- a/strawberry_django_jwt/object_types.py +++ b/strawberry_django_jwt/object_types.py @@ -43,12 +43,8 @@ class PayloadType: @strawberry.type -@inject_fields( - { - **({"refresh_token": (str, "")} if jwt_settings.JWT_ALLOW_REFRESH else {}), - } -) class TokenDataType: payload: TokenPayloadType token: str = "" - refresh_expires_in: Optional[int] = 0 + refresh_token: Optional[str] = None + refresh_expires_in: Optional[int] = None diff --git a/strawberry_django_jwt/refresh_token/shortcuts.py b/strawberry_django_jwt/refresh_token/shortcuts.py index 9b6d3ea9..58ce69cc 100644 --- a/strawberry_django_jwt/refresh_token/shortcuts.py +++ b/strawberry_django_jwt/refresh_token/shortcuts.py @@ -3,6 +3,7 @@ from django.utils.translation import gettext as _ from strawberry_django_jwt.exceptions import JSONWebTokenError +from strawberry_django_jwt.refresh_token.models import AbstractRefreshToken from strawberry_django_jwt.refresh_token.utils import get_refresh_token_model from strawberry_django_jwt.settings import jwt_settings @@ -21,7 +22,7 @@ def get_refresh_token(token, context=None): raise JSONWebTokenError(_("Invalid refresh token")) -def create_refresh_token(user, refresh_token=None): +def create_refresh_token(user, refresh_token=None) -> AbstractRefreshToken: if refresh_token is not None and jwt_settings.JWT_REUSE_REFRESH_TOKENS: refresh_token.reuse() return refresh_token diff --git a/strawberry_django_jwt/utils.py b/strawberry_django_jwt/utils.py index a0577287..e1046642 100644 --- a/strawberry_django_jwt/utils.py +++ b/strawberry_django_jwt/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from calendar import timegm from contextlib import suppress from datetime import datetime @@ -8,6 +9,7 @@ from asgiref.sync import sync_to_async from django.contrib.auth import get_user_model +from django.contrib.auth.models import User from django.http import HttpRequest from django.utils.translation import gettext as _ from graphql import GraphQLResolveInfo @@ -18,7 +20,8 @@ from strawberry.django.context import StrawberryDjangoContext from strawberry.types import Info -from strawberry_django_jwt import exceptions, object_types +from strawberry_django_jwt import exceptions, object_types, signals +from strawberry_django_jwt.refresh_token.shortcuts import create_refresh_token from strawberry_django_jwt.settings import jwt_settings if TYPE_CHECKING: # pragma: no cover @@ -238,3 +241,21 @@ def get_context(info: HttpRequest | Request | Info[Any, Any] | GraphQLResolveInf return ctx.request return ctx return info + + +async def create_user_token(user: User) -> object_types.TokenDataType: + token: object_types.TokenPayloadType = jwt_settings.JWT_PAYLOAD_HANDLER(user) + token_object = object_types.TokenDataType(payload=token, token=jwt_settings.JWT_ENCODE_HANDLER(token)) + if jwt_settings.JWT_ALLOW_REFRESH: + token_object.refresh_expires_in = token.exp - int(datetime.now().timestamp()) + if jwt_settings.JWT_LONG_RUNNING_REFRESH_TOKEN: + refresh_token = ( # type: ignore + (await sync_to_async(create_refresh_token)(user)) if asyncio.get_event_loop().is_running() else create_refresh_token(user) + ) + token_object.refresh_expires_in = ( + refresh_token.created.timestamp() + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() - int(datetime.now().timestamp()) + ) + token_object.refresh_token = refresh_token.get_token() + + signals.token_issued.send(sender=create_user_token, request=None, user=user) + return token_object diff --git a/tests/test_utils.py b/tests/test_utils.py index 6a2af239..1733f6d8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ from datetime import timedelta from functools import wraps +from imp import reload import importlib from types import ModuleType from unittest import mock @@ -11,6 +12,7 @@ import strawberry_django_jwt.object_types from strawberry_django_jwt.object_types import TokenPayloadType from strawberry_django_jwt.settings import jwt_settings +from strawberry_django_jwt.shortcuts import get_user_by_token_async from tests.decorators import OverrideJwtSettings from tests.testcases import AsyncTestCase, TestCase @@ -202,3 +204,22 @@ async def test_user_disabled_by_payload_async(self): return_value=False, ), self.assertRaises(exceptions.JSONWebTokenError): await utils.get_user_by_payload_async(payload) + + +class CreateUserTokenTestsAsync(AsyncTestCase): + @OverrideJwtSettings(JWT_LONG_RUNNING_REFRESH_TOKEN=False) + async def test_create_user_token_async(self): + reload(utils) + token = await utils.create_user_token(self.user) + user = await get_user_by_token_async(token.token) + assert user == self.user + assert token.refresh_token is None + assert token.refresh_expires_in - jwt_settings.JWT_EXPIRATION_DELTA.total_seconds() < 5 + + @OverrideJwtSettings(JWT_LONG_RUNNING_REFRESH_TOKEN=True) + async def test_create_user_token_with_refresh_async(self): + token = await utils.create_user_token(self.user) + user = await get_user_by_token_async(token.token) + assert user == self.user + assert token.refresh_token is not None + assert token.refresh_expires_in - jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() < 5