From 3de576c916ea122d21d2883abdb1fb959c98d3f8 Mon Sep 17 00:00:00 2001 From: Jamie Cockburn Date: Tue, 28 Feb 2023 17:37:20 +0000 Subject: [PATCH 1/3] Added middleware to refresh access tokens --- django_auth_adfs/backend.py | 36 +++++++++++- django_auth_adfs/config.py | 1 + django_auth_adfs/middleware.py | 16 ++++++ tests/test_authentication.py | 102 ++++++++++++++++----------------- 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 4574cb42..1c6d91e4 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -1,11 +1,13 @@ import logging +from datetime import datetime, timedelta import jwt -from django.contrib.auth import get_user_model +from django.contrib.auth import get_user_model, logout from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) +from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -398,10 +400,38 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): provider_config.load_config() adfs_response = self.exchange_auth_code(authorization_code, request) - access_token = adfs_response["access_token"] - user = self.process_access_token(access_token, adfs_response) + user = self._process_adfs_response(request, adfs_response) return user + def _process_adfs_response(self, request, adfs_response): + user = self.process_access_token(adfs_response['access_token'], adfs_response) + request.session['adfs_access_token'] = adfs_response['access_token'] + expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in']) + request.session['adfs_token_expiry'] = expiry.isoformat() + if 'refresh_token' in adfs_response: + request.session['adfs_refresh_token'] = adfs_response['refresh_token'] + request.session.save() + return user + + def process_request(self, request): + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session['adfs_token_expiry']) + if now > expiry: + try: + self._refresh_access_token(request, request.session['adfs_refresh_token']) + except (PermissionDenied, HTTPError): + logout(request) + + def _refresh_access_token(self, request, refresh_token): + provider_config.load_config() + response = provider_config.session.post( + provider_config.token_endpoint, + data=f'grant_type=refresh_token&refresh_token={refresh_token}' + ) + response.raise_for_status() + adfs_response = response.json() + self._process_adfs_response(request, adfs_response) + class AdfsAccessTokenBackend(AdfsBaseBackend): """ diff --git a/django_auth_adfs/config.py b/django_auth_adfs/config.py index 12c36dc9..9c7ebb74 100644 --- a/django_auth_adfs/config.py +++ b/django_auth_adfs/config.py @@ -72,6 +72,7 @@ def __init__(self): self.USERNAME_CLAIM = "winaccountname" self.GUEST_USERNAME_CLAIM = None self.JWT_LEEWAY = 0 + self.REFRESH_THRESHOLD = timedelta(minutes=5) self.CUSTOM_FAILED_RESPONSE_VIEW = lambda request, error_message, status: render( request, 'django_auth_adfs/login_failed.html', {'error_message': error_message}, status=status ) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 649a2390..0b4c50eb 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -4,9 +4,11 @@ from re import compile from django.conf import settings as django_settings +from django.contrib import auth from django.contrib.auth.views import redirect_to_login from django.urls import reverse +from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -49,3 +51,17 @@ def __call__(self, request): return redirect_to_login('django_auth_adfs:login-force-mfa') return self.get_response(request) + + +def adfs_refresh_middleware(get_response): + def middleware(request): + try: + backend_str = request.session[auth.BACKEND_SESSION_KEY] + except KeyError: + pass + else: + backend = auth.load_backend(backend_str) + if isinstance(backend, AdfsAuthCodeBackend): + backend.process_request(request) + return get_response() + return middleware diff --git a/tests/test_authentication.py b/tests/test_authentication.py index c16691fc..6d822a8f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,7 @@ import base64 +from django.urls import reverse + from django_auth_adfs.exceptions import MFARequired try: @@ -16,7 +18,6 @@ from mock import Mock, patch from django_auth_adfs import signals -from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.config import ProviderConfig, Settings from .models import Profile @@ -34,14 +35,13 @@ def setUp(self): @mock_adfs("2012") def test_post_authenticate_signal_send(self): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) self.assertEqual(self.signal_handler.call_count, 1) @mock_adfs("2012") def test_with_auth_code_2012(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -52,8 +52,8 @@ def test_with_auth_code_2012(self): @mock_adfs("2016") def test_with_auth_code_2016(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -64,9 +64,15 @@ def test_with_auth_code_2016(self): @mock_adfs("2016", mfa_error=True) def test_mfa_error_backends(self): - with self.assertRaises(MFARequired): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 302) + self.assertEqual( + response['Location'], + "https://adfs.example.com/adfs/oauth2/authorize/?response_type=code&" + "client_id=your-configured-client-id&resource=your-adfs-RPT-name&" + "redirect_uri=http%3A%2F%2Ftestserver%2Foauth2%2Fcallback&state=Lw%3D%3D&scope=openid&" + "amr_values=ngcmfa" + ) @mock_adfs("azure") def test_with_auth_code_azure(self): @@ -77,8 +83,8 @@ def test_with_auth_code_azure(self): with patch("django_auth_adfs.config.django_settings", settings): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -100,9 +106,8 @@ def test_with_auth_code_azure_guest_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - with self.assertRaises(PermissionDenied, msg=''): - backend = AdfsAuthCodeBackend() - _ = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("azure", guest=True) def test_with_auth_code_azure_guest_no_block(self): @@ -117,8 +122,8 @@ def test_with_auth_code_azure_guest_no_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -139,8 +144,8 @@ def test_version_two_endpoint_calls_correct_url(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -151,14 +156,15 @@ def test_version_two_endpoint_calls_correct_url(self): @mock_adfs("2016") def test_empty(self): - backend = AdfsAuthCodeBackend() - self.assertIsNone(backend.authenticate(self.request)) + response = self.client.get(reverse('django_auth_adfs:callback')) + user = response.wsgi_request.user + self.assertTrue(user.is_anonymous) @mock_adfs("2016") def test_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "nonexisting"): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -167,9 +173,9 @@ def test_group_claim(self): @mock_adfs("2016") def test_no_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", None): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -181,9 +187,9 @@ def test_group_claim_with_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -197,9 +203,9 @@ def test_group_claim_without_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -210,9 +216,9 @@ def test_group_claim_without_mirror_groups(self): @mock_adfs("2016", empty_keys=True) def test_empty_keys(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.config.provider_config.signing_keys", []): - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("2016") def test_group_removal(self): @@ -227,9 +233,8 @@ def test_group_removal(self): self.assertEqual(user.groups.all()[0].name, "group3") self.assertEqual(len(user.groups.all()), 1) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -253,9 +258,8 @@ def test_group_removal_overlap(self): self.assertEqual(user.groups.all()[1].name, "group3") self.assertEqual(len(user.groups.all()), 2) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -272,9 +276,8 @@ def test_group_to_flag_mapping(self): } with patch("django_auth_adfs.backend.settings.GROUP_TO_FLAG_MAPPING", group_to_flag_mapping): with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", {}): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -289,9 +292,8 @@ def test_boolean_claim_mapping(self): "is_superuser": "user_is_superuser", } with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", boolean_claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -312,9 +314,8 @@ def test_extended_model_claim_mapping_missing_instance(self): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -340,9 +341,8 @@ def create_profile(sender, instance, created, **kwargs): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -493,5 +493,5 @@ def test_nonexisting_user(self): settings.AUTH_ADFS["CREATE_NEW_USERS"] = False with patch("django_auth_adfs.config.django_settings", settings),\ patch("django_auth_adfs.backend.settings", Settings()): - backend = AdfsAuthCodeBackend() - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) From 9abbcbdd14f2f6fd9abd474247565f8dff468784 Mon Sep 17 00:00:00 2001 From: Jamie Cockburn Date: Tue, 28 Feb 2023 18:55:37 +0000 Subject: [PATCH 2/3] Added middleware to refresh access tokens --- django_auth_adfs/backend.py | 15 ++++++++----- django_auth_adfs/middleware.py | 2 +- tests/settings.py | 1 + tests/test_authentication.py | 40 ++++++++++++++++++++++++++++++---- tests/urls.py | 5 ++++- tests/utils.py | 37 ++++++++++++++++++++++++------- tests/views.py | 9 ++++++++ 7 files changed, 90 insertions(+), 19 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 1c6d91e4..8f75c948 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -396,6 +396,11 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): logger.debug("Authentication backend was called but no authorization code was received") return + # If there's no request object, we pass control to the next authentication backend + if request is None: + logger.debug("Authentication backend was called without request") + return + # If loaded data is too old, reload it again provider_config.load_config() @@ -405,20 +410,20 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): def _process_adfs_response(self, request, adfs_response): user = self.process_access_token(adfs_response['access_token'], adfs_response) - request.session['adfs_access_token'] = adfs_response['access_token'] + request.session['_adfs_access_token'] = adfs_response['access_token'] expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in']) - request.session['adfs_token_expiry'] = expiry.isoformat() + request.session['_adfs_token_expiry'] = expiry.isoformat() if 'refresh_token' in adfs_response: - request.session['adfs_refresh_token'] = adfs_response['refresh_token'] + request.session['_adfs_refresh_token'] = adfs_response['refresh_token'] request.session.save() return user def process_request(self, request): now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session['adfs_token_expiry']) + expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) if now > expiry: try: - self._refresh_access_token(request, request.session['adfs_refresh_token']) + self._refresh_access_token(request, request.session['_adfs_refresh_token']) except (PermissionDenied, HTTPError): logout(request) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 0b4c50eb..0163ea78 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -63,5 +63,5 @@ def middleware(request): backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeBackend): backend.process_request(request) - return get_response() + return get_response(request) return middleware diff --git a/tests/settings.py b/tests/settings.py index 81d397c7..121507e0 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -35,6 +35,7 @@ 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'django_auth_adfs.middleware.adfs_refresh_middleware', 'django_auth_adfs.middleware.LoginRequiredMiddleware', ) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 6d822a8f..edd1b3bc 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,7 @@ import base64 +from datetime import datetime, timedelta + from django.urls import reverse from django_auth_adfs.exceptions import MFARequired @@ -12,9 +14,9 @@ from copy import deepcopy from django.contrib.auth.models import Group, User -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import ObjectDoesNotExist from django.db.models.signals import post_save -from django.test import RequestFactory, TestCase +from django.test import TestCase from mock import Mock, patch from django_auth_adfs import signals @@ -29,13 +31,12 @@ def setUp(self): Group.objects.create(name='group1') Group.objects.create(name='group2') Group.objects.create(name='group3') - self.request = RequestFactory().get('/oauth2/callback') self.signal_handler = Mock() signals.post_authenticate.connect(self.signal_handler) @mock_adfs("2012") def test_post_authenticate_signal_send(self): - response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) self.assertEqual(self.signal_handler.call_count, 1) @mock_adfs("2012") @@ -495,3 +496,34 @@ def test_nonexisting_user(self): patch("django_auth_adfs.backend.settings", Settings()): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) self.assertEqual(response.status_code, 401) + + @mock_adfs("2016") + def test_access_token_unexpired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @mock_adfs("2016") + def test_access_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=1) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @mock_adfs("2016", refresh_token_expired=True) + def test_refresh_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=1) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 302) + self.assertEqual(response['Location'], f"{reverse('django_auth_adfs:login')}?next=/") + self.assertTrue(response.wsgi_request.user.is_anonymous) diff --git a/tests/urls.py b/tests/urls.py index e3a608df..9ad8a6e7 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,9 @@ -from django.urls import include, re_path +from django.urls import include, re_path, path + +from tests.views import TestView urlpatterns = [ + path('', TestView.as_view(), name='test'), re_path(r'^oauth2/', include('django_auth_adfs.urls')), re_path(r'^oauth2/', include('django_auth_adfs.drf_urls')), ] diff --git a/tests/utils.py b/tests/utils.py index f6040d27..bda61c6d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import time from datetime import datetime, tzinfo, timedelta from functools import partial +from urllib.parse import parse_qs import jwt import responses @@ -98,9 +99,14 @@ def build_access_token_azure_groups_in_claim_source(request): return do_build_access_token(request, issuer, groups_in_claim_names=True) +def build_access_token_adfs_expired(request): + issuer = "http://adfs.example.com/adfs/services/trust" + return do_build_access_token(request, issuer, refresh_token_expired=True) + + def do_build_mfa_error(request): response = {'error_description': 'AADSTS50076'} - return 400, [], json.dumps(response) + return 400, {}, json.dumps(response) def do_build_graph_response(request): @@ -111,7 +117,11 @@ def do_build_graph_response_no_group_perm(request): return do_build_ms_graph_groups(request, missing_group_names=True) -def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False): +def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False, + refresh_token_expired=False): + data = parse_qs(request.body) + if data.get('grant_type') == ['refresh_token'] and data.get('refresh_token') == ['expired_refresh_token']: + return 401, {}, None issued_at = int(time.time()) expires = issued_at + 3600 auth_time = datetime.utcnow() @@ -159,16 +169,20 @@ def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, } } token = jwt.encode(claims, signing_key_b, algorithm="RS256") + if refresh_token_expired: + refresh_token = 'expired_refresh_token' + else: + refresh_token = 'random_refresh_token' response = { 'resource': 'django_website.adfs.relying_party_id', 'token_type': 'bearer', 'refresh_token_expires_in': 28799, - 'refresh_token': 'random_refresh_token', + 'refresh_token': refresh_token, 'expires_in': 3600, 'id_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_obo_access_token(request): @@ -228,7 +242,7 @@ def do_build_obo_access_token(request): 'refresh_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_ms_graph_groups(request, missing_group_names=False): @@ -308,7 +322,7 @@ def do_build_ms_graph_groups(request, missing_group_names=False): if missing_group_names: for group in response["value"]: group["displayName"] = None - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def build_openid_keys(request, empty_keys=False): @@ -337,7 +351,7 @@ def build_openid_keys(request, empty_keys=False): }, ] } - return 200, [], json.dumps(keys) + return 200, {}, json.dumps(keys) def build_adfs_meta(request): @@ -345,7 +359,7 @@ def build_adfs_meta(request): data = "".join(f.readlines()) data = data.replace("REPLACE_WITH_CERT_A", base64.b64encode(signing_cert_a).decode()) data = data.replace("REPLACE_WITH_CERT_B", base64.b64encode(signing_cert_b).decode()) - return 200, [], data + return 200, {}, data def mock_adfs( @@ -356,6 +370,7 @@ def mock_adfs( version=None, requires_obo=False, missing_graph_group_perm=False, + refresh_token_expired=False, ): if adfs_version not in ["2012", "2016", "azure"]: raise NotImplementedError("This version of ADFS is not implemented") @@ -465,6 +480,12 @@ def wrapper(*original_args, **original_kwargs): callback=do_build_mfa_error, content_type='application/json', ) + elif refresh_token_expired: + rsps.add_callback( + rsps.POST, token_endpoint, + callback=build_access_token_adfs_expired, + content_type='application/json', + ) else: rsps.add_callback( rsps.POST, token_endpoint, diff --git a/tests/views.py b/tests/views.py index b16e4025..7bb0bedd 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,2 +1,11 @@ +from django.http import HttpResponse +from django.views import View + + def test_failed_response(request, error_message, status): pass + + +class TestView(View): + def get(self, request): + return HttpResponse('okay') From b43f25906c9f2242f7e7ccf55c3b9e15db7ca2e2 Mon Sep 17 00:00:00 2001 From: Dominik Vogt Date: Tue, 2 Jul 2024 20:06:07 +0200 Subject: [PATCH 3/3] Moved refresh access token check into middleware --- django_auth_adfs/backend.py | 14 ++------------ django_auth_adfs/middleware.py | 16 +++++++++++++++- tests/test_authentication.py | 2 +- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 8f75c948..95640173 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -2,12 +2,11 @@ from datetime import datetime, timedelta import jwt -from django.contrib.auth import get_user_model, logout +from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) -from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -418,16 +417,7 @@ def _process_adfs_response(self, request, adfs_response): request.session.save() return user - def process_request(self, request): - now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) - if now > expiry: - try: - self._refresh_access_token(request, request.session['_adfs_refresh_token']) - except (PermissionDenied, HTTPError): - logout(request) - - def _refresh_access_token(self, request, refresh_token): + def refresh_access_token(self, request, refresh_token): provider_config.load_config() response = provider_config.session.post( provider_config.token_endpoint, diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 0163ea78..d39d9f0a 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -1,12 +1,17 @@ """ Based on https://djangosnippets.org/snippets/1179/ """ +import logging +from datetime import datetime from re import compile from django.conf import settings as django_settings from django.contrib import auth +from django.contrib.auth import logout from django.contrib.auth.views import redirect_to_login +from django.core.exceptions import PermissionDenied from django.urls import reverse +from requests import HTTPError from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.exceptions import MFARequired @@ -21,6 +26,8 @@ if hasattr(settings, 'LOGIN_EXEMPT_URLS'): LOGIN_EXEMPT_URLS += [compile(expr) for expr in settings.LOGIN_EXEMPT_URLS] +logger = logging.getLogger("django_auth_adfs") + class LoginRequiredMiddleware: """ @@ -62,6 +69,13 @@ def middleware(request): else: backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeBackend): - backend.process_request(request) + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) + if now > expiry: + try: + backend.refresh_access_token(request, request.session['_adfs_refresh_token']) + except (PermissionDenied, HTTPError) as error: + logger.debug("Error refreshing access token: %s", error) + logout(request) return get_response(request) return middleware diff --git a/tests/test_authentication.py b/tests/test_authentication.py index edd1b3bc..dfc788bd 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -520,7 +520,7 @@ def test_refresh_token_expired(self): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) self.assertFalse(response.wsgi_request.user.is_anonymous) fromisoformat = datetime.fromisoformat - with patch('django_auth_adfs.backend.datetime') as dt: + with patch('django_auth_adfs.middleware.datetime') as dt: dt.fromisoformat = fromisoformat dt.now.return_value = datetime.now() + timedelta(hours=1) response = self.client.get(reverse('test'))