Allow other header claims in tokens #531

Open
73VW opened this issue Feb 10, 2022 · 11 comments

Comments

@73VW

73VW commented Feb 10, 2022

As defined in RFC 7515, section 4.1, tokens can include several more header claims than just typ and alg, which are the only ones this library currently allows.

I tried to include a kid header, since I use signed tokens, but I couldn't.

Using PyJWT, I was able to add it to the token string, but when I passed that string to the RefreshToken(token) constructor, all custom headers were dropped.
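For context, here is a stdlib-only sketch of what adding a header parameter such as kid before signing amounts to. It is roughly what PyJWT does when `headers=` is passed to `jwt.encode`, but it is an illustration, not PyJWT's implementation. The helper names, the key `b"secret"` and the kid `"my-key-1"` are made up for the example, and only HS256 is handled.

```python
import base64
import hashlib
import hmac
import json
from typing import Optional


def b64url(data: bytes) -> str:
    # Base64url without "=" padding, as used by JWS (RFC 7515).
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    # Restore the stripped padding before decoding.
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def encode_hs256(payload: dict, key: bytes, headers: Optional[dict] = None) -> str:
    # Extra header parameters (e.g. "kid") are merged in; "alg" is set last so it
    # always matches the algorithm actually used to sign.
    header = {"typ": "JWT", **(headers or {}), "alg": "HS256"}
    signing_input = ".".join(
        b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, payload)
    )
    signature = hmac.new(key, signing_input.encode("ascii"), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


def unverified_header(token: str) -> dict:
    # Decodes the first segment only; the signature is NOT checked here.
    return json.loads(b64url_decode(token.split(".")[0]))


token = encode_hs256({"sub": "42"}, b"secret", headers={"kid": "my-key-1"})
print(unverified_header(token))
```

Reading the header back shows kid next to typ and alg. The problem described in this issue is that simple-jwt's Token only keeps the payload, so re-encoding it produces a fresh header without such extra parameters.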

I checked the docs and nothing seems to cover this use case.

I haven't dug much into the code, though.

As for the kid claim, I suggest including it in the header by default when the token is signed.

(AuthLib documentation for reference)

This is somewhat related to #491, as kid might be useful when combined with a JWK endpoint.
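To illustrate why kid pairs well with a JWK endpoint: a verifier can read kid from the (unverified) token header and pick the matching key from the published JWK Set before checking the signature. This is a minimal stdlib sketch; `find_jwk` is a hypothetical helper, not part of simple-jwt, and the key IDs and elided `n` values are placeholders.

```python
from typing import Any, Dict, List


def find_jwk(jwks: Dict[str, Any], kid: str) -> Dict[str, Any]:
    # A JWK Set (RFC 7517, section 5) holds its keys under the "keys" member.
    keys: List[Dict[str, Any]] = jwks.get("keys", [])
    matches = [k for k in keys if k.get("kid") == kid]
    if len(matches) != 1:
        # No match or an ambiguous one: reject rather than guess which key to trust.
        raise KeyError(f"no unique key with kid={kid!r}")
    return matches[0]


# Placeholder JWK Set; real "n" values are long base64url-encoded moduli.
jwks = {
    "keys": [
        {"kty": "RSA", "kid": "2022-01", "use": "sig", "n": "...", "e": "AQAB"},
        {"kty": "RSA", "kid": "2022-02", "use": "sig", "n": "...", "e": "AQAB"},
    ]
}
print(find_jwk(jwks, "2022-02")["kid"])
```

Without a kid in the header, a verifier facing several published keys (for example during key rotation) would have to try each one in turn.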

@73VW
Author

73VW commented Feb 10, 2022

It seems that the first part of my issue can be solved with what was done in #517.

Sadly, I couldn't find any issue related to it.

Any idea when this will be released on PyPI?

@73VW
Author

73VW commented Feb 10, 2022

Well, after digging into the code, I managed to include the kid header claim as I wanted without using what's in #517.

I had to redefine quite a few classes:

views.py

import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import (TokenObtainPairSerializer,
                                                  TokenRefreshSerializer)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken, Token


class TokenBackendWithHeaders(TokenBackend):

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(jwt_payload, self.signing_key,
                           algorithm=self.algorithm, headers=headers)
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token


class TokenWithAnotherTokenBackend(Token):
    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

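    def __init__(self, token=None, verify=True):
        Token.__init__(self, token, verify)
        # Keep extra header claims (e.g. "kid") of a decoded token so that they
        # survive re-encoding and are copied to access tokens on refresh.
        self.headers = {}
        if token is not None:
            self.headers = {
                name: value
                for name, value in jwt.get_unverified_header(token).items()
                if name not in ("alg", "typ")
            }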

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string.
        """
        return self.get_token_backend().encode(self.payload, self.headers)


class AccessTokenWithAnotherTokenBackend(AccessToken, TokenWithAnotherTokenBackend):
    pass


class RefreshTokenWithAnotherTokenBackend(RefreshToken, TokenWithAnotherTokenBackend):

    @property
    def access_token(self):
        """
        Returns an access token created from this refresh token.  Copies all
        claims present in this refresh token to the new access token except
        those claims listed in the `no_copy_claims` attribute.
        """
        access = AccessTokenWithAnotherTokenBackend()

        # Use instantiation time of refresh token as relative timestamp for
        # access token "exp" claim.  This ensures that both a refresh and
        # access token expire relative to the same time if they are created as
        # a pair.
        access.set_exp(from_time=self.current_time)

        no_copy = self.no_copy_claims
        for claim, value in self.payload.items():
            if claim in no_copy:
                continue
            access[claim] = value

        for claim, value in self.headers.items():
            access.headers[claim] = value

        return access


class TokenObtainPairSerializerDifferentToken(TokenObtainPairSerializer):
    token_class = RefreshTokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):

        key = JsonWebKey.import_key(
            settings.SIMPLE_JWT['VERIFYING_KEY'], {'kty': 'RSA'})
        token = cls.token_class.for_user(user)

        # Add custom header claims
        token.headers['kid'] = key.thumbprint()

        return token


class TokenRefreshSerializerDifferentToken(TokenRefreshSerializer):

    # Needed to redefine all of this due to the hardcoded "RefreshToken" in
    # the original code. Replaced here by "RefreshTokenWithAnotherTokenBackend"
    # PR for fixing this was already merged. New version of simple-jwt should
    # include changes contained in
    # https://github.com/jazzband/djangorestframework-simplejwt/pull/517
    def validate(self, attrs):
        refresh = RefreshTokenWithAnotherTokenBackend(attrs['refresh'])

        data = {'access': str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                try:
                    # Attempt to blacklist the given refresh token
                    refresh.blacklist()
                except AttributeError:
                    # If blacklist app not installed, `blacklist` method will
                    # not be present
                    pass

            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()

            data['refresh'] = str(refresh)

        return data


class TokenObtainPairView(original_views.TokenObtainPairView):
    serializer_class = TokenObtainPairSerializerDifferentToken


class TokenRefreshView(original_views.TokenRefreshView):
    serializer_class = TokenRefreshSerializerDifferentToken

urls.py

"""e_abeilles URL Configuration

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/4.0/topics/http/urls/
Examples:
Function views
    1. Add an import:  from my_app import views
    2. Add a URL to urlpatterns:  path('', views.home, name='home')
Class-based views
    1. Add an import:  from other_app.views import Home
    2. Add a URL to urlpatterns:  path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns:  path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include

from rest_framework_simplejwt import views as jwt_views
from my_package import views

urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/token/', views.TokenObtainPairView.as_view(),
         name='token_obtain_pair'),
    path('api/token/refresh/', views.TokenRefreshView.as_view(),
         name='token_refresh'),
]

@73VW
Author

73VW commented Feb 10, 2022

Could this be included in the base code? I can open a PR if you wish!

@Andrew-Chen-Wang
Member

What we’ve done in the past is to add a setting in SIMPLE_JWT that takes a callable or a dotted import string. In the serializer, we can pass the token to your function. This is similar to the authorization callable.
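A minimal sketch of that pattern, assuming a hypothetical `TOKEN_HEADERS_CALLABLE` entry in SIMPLE_JWT (simple-jwt has no such setting) whose hook receives the token payload and returns extra header claims. In a Django project, `django.utils.module_loading.import_string` would normally handle the dotted-path resolution; `importlib` is used here to keep the example self-contained.

```python
import importlib
from typing import Any, Callable, Dict, Union

# A hook takes the token payload and returns extra JOSE header claims.
HeaderHook = Callable[[Dict[str, Any]], Dict[str, Any]]


def resolve_hook(value: Union[str, HeaderHook]) -> HeaderHook:
    # Accept either a callable or a dotted import path such as "myapp.jwt.headers".
    if callable(value):
        return value
    module_path, _, attr = value.rpartition(".")
    if not module_path:
        raise ImportError(f"{value!r} is not a dotted import path")
    return getattr(importlib.import_module(module_path), attr)


# Hypothetical setting; simple-jwt does not read a key with this name.
SIMPLE_JWT = {"TOKEN_HEADERS_CALLABLE": lambda payload: {"kid": "my-key-1"}}

hook = resolve_hook(SIMPLE_JWT["TOKEN_HEADERS_CALLABLE"])
headers = hook({"user_id": 1, "token_type": "access"})
print(headers)
```

If the hook is called before the token is encoded, its result can be passed straight to `jwt.encode(..., headers=...)` without a decode/re-encode round trip.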

@73VW
Author

73VW commented Feb 11, 2022

@Andrew-Chen-Wang That might be possible, but I don't think it's the way to go, as it involves encoding a token -> passing it to the callback -> decoding it -> re-encoding it with the added header -> returning it.

Performance-wise, adding the header before encoding would be much better, don't you think?

@Andrew-Chen-Wang
Member

Yes, it definitely would be. I just worry about the ordering and people missing something with override classes. But please open a PR and we shall deliberate :)

@rj76

rj76 commented Jun 19, 2022

For anyone interested, here is the same approach for sliding tokens:

import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import TokenObtainSlidingSerializer
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import SlidingToken, Token


class TokenBackendWithHeaders(TokenBackend):

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(jwt_payload, self.signing_key,
                           algorithm=self.algorithm, headers=headers)
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token


class TokenWithAnotherTokenBackend(Token):
    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

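    def __init__(self, token=None, verify=True):
        Token.__init__(self, token, verify)
        # Keep extra header claims (e.g. "kid") of a decoded token so that they
        # survive re-encoding.
        self.headers = {}
        if token is not None:
            self.headers = {
                name: value
                for name, value in jwt.get_unverified_header(token).items()
                if name not in ("alg", "typ")
            }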

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string.
        """
        return self.get_token_backend().encode(self.payload, self.headers)


class SlidingTokenWithAnotherTokenBackend(SlidingToken, TokenWithAnotherTokenBackend):
    pass


class TokenObtainSlidingSerializerDifferentToken(TokenObtainSlidingSerializer):
    token_class = SlidingTokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):

        key = JsonWebKey.import_key(
            settings.SIMPLE_JWT['VERIFYING_KEY'], {'kty': 'RSA'})
        token = cls.token_class.for_user(user)

        # Add custom header claims
        token.headers['kid'] = key.thumbprint()

        return token


class TokenObtainSlidingView(original_views.TokenObtainSlidingView):
    serializer_class = TokenObtainSlidingSerializerDifferentToken

@nixsiow

nixsiow commented Mar 31, 2023

Has this been incorporated or solved in the latest codebase? I am currently facing the exact same issue of trying to add a 'kid' claim to the header of the signed token. It's strange that this isn't mentioned anywhere in the docs.

@Andrew-Chen-Wang
Member

This is not implemented.

@steven-jeanneret

Would a new EXTRA_JWT_HEADERS setting be a solution?

I'm facing the same problem: I want to add kid to the headers.
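A sketch of how such a setting could be merged into the JOSE header before signing. `EXTRA_JWT_HEADERS` is only the name proposed above, not an existing option, and `build_jose_header` is a hypothetical helper; the resulting dict is what would be passed as `headers=` to PyJWT's `jwt.encode`.

```python
from typing import Any, Dict, Mapping, Optional

# "alg" must stay under the backend's control so that it matches the algorithm
# actually used for signing (RFC 7515, section 4.1.1).
PROTECTED_PARAMS = frozenset({"alg"})


def build_jose_header(
    algorithm: str,
    extra: Optional[Mapping[str, Any]] = None,
    per_token: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
    # Later sources win: static settings first, then per-token values (e.g. kid).
    merged: Dict[str, Any] = {"typ": "JWT"}
    for source in (extra or {}, per_token or {}):
        clash = PROTECTED_PARAMS & set(source)
        if clash:
            raise ValueError(f"cannot override protected header params: {sorted(clash)}")
        merged.update(source)
    merged["alg"] = algorithm
    return merged


# Hypothetical settings; EXTRA_JWT_HEADERS is not read by simple-jwt today.
SIMPLE_JWT = {"ALGORITHM": "RS256", "EXTRA_JWT_HEADERS": {"kid": "2022-02"}}

header = build_jose_header(SIMPLE_JWT["ALGORITHM"], SIMPLE_JWT["EXTRA_JWT_HEADERS"])
print(header)
```

A static setting covers a fixed kid; a per-token source is still needed when the value depends on the key or user, which is why the sketch accepts both.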

@adamJLev

At this point, wouldn't it be good to just always add the kid header by default?
It's part of the JWK standard now:
https://datatracker.ietf.org/doc/html/rfc7517#section-4.5
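One common way to derive a default kid is the JWK Thumbprint defined in RFC 7638, which, as far as I can tell, is also what Authlib's `key.thumbprint()` in the snippets above computes. A stdlib-only sketch, assuming the public key is already available as a JWK dict; the `n` value below is a placeholder, not real key material.

```python
import base64
import hashlib
import json
from typing import Any, Dict

# Required members per key type (RFC 7638, section 3.2).
REQUIRED_MEMBERS = {
    "RSA": ("e", "kty", "n"),
    "EC": ("crv", "kty", "x", "y"),
    "oct": ("k", "kty"),
}


def jwk_thumbprint(jwk: Dict[str, Any]) -> str:
    # Keep only the required members, serialize them with sorted keys and no
    # whitespace, hash with SHA-256, then base64url-encode without padding.
    members = {name: jwk[name] for name in REQUIRED_MEMBERS[jwk["kty"]]}
    canonical = json.dumps(members, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# Placeholder public key; substitute the real JWK form of VERIFYING_KEY.
jwk = {"kty": "RSA", "e": "AQAB", "n": "placeholder-modulus", "use": "sig"}
kid = jwk_thumbprint(jwk)
print(kid)
```

Because only the required members are hashed, optional members such as use, alg or kid itself do not change the result, so the value is stable for a given key and a verifier can recompute it from the published JWK.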

6 participants