Skip to content

Commit

Permalink
Add a utility function for creating user tokens (#341)
Browse files Browse the repository at this point in the history
* Add a utility function for creating user tokens

* Fix async context

* Fix mypy

* Update tests and behavior
  • Loading branch information
KundaPanda authored Jul 25, 2022
1 parent 4535f85 commit 46072d7
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 13 deletions.
7 changes: 2 additions & 5 deletions strawberry_django_jwt/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,8 @@ def refresh_expiration(f):
@wraps(f)
def wrapper(cls, *args, **kwargs):
def on_resolve(payload):
payload.refresh_expires_in = (
timegm(datetime.utcnow().utctimetuple()) + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds()
if jwt_settings.JWT_LONG_RUNNING_REFRESH_TOKEN
else None
)
if jwt_settings.JWT_ALLOW_REFRESH:
payload.refresh_expires_in = timegm(datetime.utcnow().utctimetuple()) + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds()
return payload

result = f(cls, *args, **kwargs)
Expand Down
8 changes: 2 additions & 6 deletions strawberry_django_jwt/object_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ class PayloadType:


@strawberry.type
@inject_fields(
{
**({"refresh_token": (str, "")} if jwt_settings.JWT_ALLOW_REFRESH else {}),
}
)
class TokenDataType:
payload: TokenPayloadType
token: str = ""
refresh_expires_in: Optional[int] = 0
refresh_token: Optional[str] = None
refresh_expires_in: Optional[int] = None
3 changes: 2 additions & 1 deletion strawberry_django_jwt/refresh_token/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.utils.translation import gettext as _

from strawberry_django_jwt.exceptions import JSONWebTokenError
from strawberry_django_jwt.refresh_token.models import AbstractRefreshToken
from strawberry_django_jwt.refresh_token.utils import get_refresh_token_model
from strawberry_django_jwt.settings import jwt_settings

Expand All @@ -21,7 +22,7 @@ def get_refresh_token(token, context=None):
raise JSONWebTokenError(_("Invalid refresh token"))


def create_refresh_token(user, refresh_token=None):
def create_refresh_token(user, refresh_token=None) -> AbstractRefreshToken:
if refresh_token is not None and jwt_settings.JWT_REUSE_REFRESH_TOKENS:
refresh_token.reuse()
return refresh_token
Expand Down
23 changes: 22 additions & 1 deletion strawberry_django_jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from calendar import timegm
from contextlib import suppress
from datetime import datetime
Expand All @@ -8,6 +9,7 @@

from asgiref.sync import sync_to_async
from django.contrib.auth import get_user_model
from django.contrib.auth.models import User
from django.http import HttpRequest
from django.utils.translation import gettext as _
from graphql import GraphQLResolveInfo
Expand All @@ -18,7 +20,8 @@
from strawberry.django.context import StrawberryDjangoContext
from strawberry.types import Info

from strawberry_django_jwt import exceptions, object_types
from strawberry_django_jwt import exceptions, object_types, signals
from strawberry_django_jwt.refresh_token.shortcuts import create_refresh_token
from strawberry_django_jwt.settings import jwt_settings

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -238,3 +241,21 @@ def get_context(info: HttpRequest | Request | Info[Any, Any] | GraphQLResolveInf
return ctx.request
return ctx
return info


async def create_user_token(user: User) -> object_types.TokenDataType:
token: object_types.TokenPayloadType = jwt_settings.JWT_PAYLOAD_HANDLER(user)
token_object = object_types.TokenDataType(payload=token, token=jwt_settings.JWT_ENCODE_HANDLER(token))
if jwt_settings.JWT_ALLOW_REFRESH:
token_object.refresh_expires_in = token.exp - int(datetime.now().timestamp())
if jwt_settings.JWT_LONG_RUNNING_REFRESH_TOKEN:
refresh_token = ( # type: ignore
(await sync_to_async(create_refresh_token)(user)) if asyncio.get_event_loop().is_running() else create_refresh_token(user)
)
token_object.refresh_expires_in = (
refresh_token.created.timestamp() + jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() - int(datetime.now().timestamp())
)
token_object.refresh_token = refresh_token.get_token()

signals.token_issued.send(sender=create_user_token, request=None, user=user)
return token_object
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import timedelta
from functools import wraps
from imp import reload
import importlib
from types import ModuleType
from unittest import mock
Expand All @@ -11,6 +12,7 @@
import strawberry_django_jwt.object_types
from strawberry_django_jwt.object_types import TokenPayloadType
from strawberry_django_jwt.settings import jwt_settings
from strawberry_django_jwt.shortcuts import get_user_by_token_async
from tests.decorators import OverrideJwtSettings
from tests.testcases import AsyncTestCase, TestCase

Expand Down Expand Up @@ -202,3 +204,22 @@ async def test_user_disabled_by_payload_async(self):
return_value=False,
), self.assertRaises(exceptions.JSONWebTokenError):
await utils.get_user_by_payload_async(payload)


class CreateUserTokenTestsAsync(AsyncTestCase):
@OverrideJwtSettings(JWT_LONG_RUNNING_REFRESH_TOKEN=False)
async def test_create_user_token_async(self):
reload(utils)
token = await utils.create_user_token(self.user)
user = await get_user_by_token_async(token.token)
assert user == self.user
assert token.refresh_token is None
assert token.refresh_expires_in - jwt_settings.JWT_EXPIRATION_DELTA.total_seconds() < 5

@OverrideJwtSettings(JWT_LONG_RUNNING_REFRESH_TOKEN=True)
async def test_create_user_token_with_refresh_async(self):
token = await utils.create_user_token(self.user)
user = await get_user_by_token_async(token.token)
assert user == self.user
assert token.refresh_token is not None
assert token.refresh_expires_in - jwt_settings.JWT_REFRESH_EXPIRATION_DELTA.total_seconds() < 5

0 comments on commit 46072d7

Please sign in to comment.