From fcfc31856e62d6c94b2119b5623442f50391beed Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Tue, 6 Oct 2020 09:30:41 -0400 Subject: [PATCH] Revert "Openid Connect Core support - Round 2 (#859)" This reverts commit 4655c030be15616ba6e0872253a2c15a897d9701. --- .gitignore | 2 +- oauth2_provider/admin.py | 9 +- oauth2_provider/forms.py | 1 - .../migrations/0002_auto_20190406_1805.py | 2 + .../migrations/0003_auto_20200902_2022.py | 48 - oauth2_provider/models.py | 114 -- oauth2_provider/oauth2_backends.py | 26 +- oauth2_provider/oauth2_validators.py | 376 +---- oauth2_provider/settings.py | 62 +- oauth2_provider/urls.py | 9 +- oauth2_provider/views/__init__.py | 16 +- oauth2_provider/views/application.py | 4 +- oauth2_provider/views/base.py | 74 +- oauth2_provider/views/introspect.py | 2 +- oauth2_provider/views/mixins.py | 31 +- oauth2_provider/views/oidc.py | 95 -- setup.cfg | 1 - tests/migrations/0001_initial.py | 7 +- tests/settings.py | 27 - tests/test_application_views.py | 1 - tests/test_authorization_code.py | 682 ++------- tests/test_hybrid.py | 1264 ----------------- tests/test_implicit.py | 198 +-- tests/test_oauth2_backends.py | 4 +- tests/test_oauth2_validators.py | 7 - tests/test_oidc_views.py | 77 - tests/urls.py | 8 +- tox.ini | 9 +- 28 files changed, 259 insertions(+), 2897 deletions(-) delete mode 100644 oauth2_provider/migrations/0003_auto_20200902_2022.py delete mode 100644 oauth2_provider/views/oidc.py delete mode 100644 tests/test_hybrid.py delete mode 100644 tests/test_oidc_views.py diff --git a/.gitignore b/.gitignore index c22ef00fa..af644d1e3 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports -.pytest_cache +.cache .coverage .tox .pytest_cache/ diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index a8d69e623..8b963d981 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -2,7 +2,7 @@ from .models import ( get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model + get_grant_model, get_refresh_token_model ) @@ -26,11 +26,6 @@ class AccessTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", "source_refresh_token") -class IDTokenAdmin(admin.ModelAdmin): - list_display = ("token", "user", "application", "expires") - raw_id_fields = ("user", ) - - class RefreshTokenAdmin(admin.ModelAdmin): list_display = ("token", "user", "application") raw_id_fields = ("user", "access_token") @@ -39,11 +34,9 @@ class RefreshTokenAdmin(admin.ModelAdmin): Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() -IDToken = get_id_token_model() RefreshToken = get_refresh_token_model() admin.site.register(Application, ApplicationAdmin) admin.site.register(Grant, GrantAdmin) admin.site.register(AccessToken, AccessTokenAdmin) -admin.site.register(IDToken, IDTokenAdmin) admin.site.register(RefreshToken, RefreshTokenAdmin) diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 41129c449..2e465959a 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -5,7 +5,6 @@ class AllowForm(forms.Form): allow = forms.BooleanField(required=False) redirect_uri = forms.CharField(widget=forms.HiddenInput()) scope = forms.CharField(widget=forms.HiddenInput()) - nonce = forms.CharField(required=False, widget=forms.HiddenInput()) client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) diff --git a/oauth2_provider/migrations/0002_auto_20190406_1805.py b/oauth2_provider/migrations/0002_auto_20190406_1805.py index bcacc23ce..8ca177abf 100644 --- a/oauth2_provider/migrations/0002_auto_20190406_1805.py +++ b/oauth2_provider/migrations/0002_auto_20190406_1805.py @@ -1,3 +1,5 @@ +# Generated by Django 2.2 on 2019-04-06 18:05 + from django.db import migrations, models diff --git a/oauth2_provider/migrations/0003_auto_20200902_2022.py b/oauth2_provider/migrations/0003_auto_20200902_2022.py deleted file mode 100644 index 684949c9d..000000000 --- a/oauth2_provider/migrations/0003_auto_20200902_2022.py +++ /dev/null @@ -1,48 +0,0 @@ -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - -from oauth2_provider.settings import oauth2_settings - - -class Migration(migrations.Migration): - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('oauth2_provider', '0002_auto_20190406_1805'), - ] - - operations = [ - migrations.AddField( - model_name='application', - name='algorithm', - field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), - ), - migrations.AlterField( - model_name='application', - name='authorization_grant_type', - field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), - ), - migrations.CreateModel( - name='IDToken', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('token', models.TextField(unique=True)), - ('expires', models.DateTimeField()), - ('scope', models.TextField(blank=True)), - ('created', models.DateTimeField(auto_now_add=True)), - ('updated', models.DateTimeField(auto_now=True)), - ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), - ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), - ], - options={ - 'abstract': False, - 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', - }, - ), - migrations.AddField( - model_name='accesstoken', - name='id_token', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), - ), - ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 7135192db..5676bc0c5 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,4 +1,3 @@ -import json import logging from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -10,7 +9,6 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend @@ -52,20 +50,11 @@ class AbstractApplication(models.Model): GRANT_IMPLICIT = "implicit" GRANT_PASSWORD = "password" GRANT_CLIENT_CREDENTIALS = "client-credentials" - GRANT_OPENID_HYBRID = "openid-hybrid" GRANT_TYPES = ( (GRANT_AUTHORIZATION_CODE, _("Authorization code")), (GRANT_IMPLICIT, _("Implicit")), (GRANT_PASSWORD, _("Resource owner password-based")), (GRANT_CLIENT_CREDENTIALS, _("Client credentials")), - (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), - ) - - RS256_ALGORITHM = "RS256" - HS256_ALGORITHM = "HS256" - ALGORITHM_TYPES = ( - (RS256_ALGORITHM, _("RSA with SHA-2 256")), - (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) @@ -93,7 +82,6 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) - algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=RS256_ALGORITHM) class Meta: abstract = True @@ -294,10 +282,6 @@ class AbstractAccessToken(models.Model): related_name="refreshed_access_token" ) token = models.CharField(max_length=255, unique=True, ) - id_token = models.OneToOneField( - oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="access_token" - ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, ) @@ -431,99 +415,6 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" -class AbstractIDToken(models.Model): - """ - An IDToken instance represents the actual token to - access user's resources, as in :openid:`2`. - - Fields: - - * :attr:`user` The Django user representing resources' owner - * :attr:`token` ID token - * :attr:`application` Application instance - * :attr:`expires` Date and time of token expiration, in DateTime format - * :attr:`scope` Allowed scopes - """ - id = models.BigAutoField(primary_key=True) - user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="%(app_label)s_%(class)s" - ) - token = models.TextField(unique=True) - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, - ) - expires = models.DateTimeField() - scope = models.TextField(blank=True) - - created = models.DateTimeField(auto_now_add=True) - updated = models.DateTimeField(auto_now=True) - - def is_valid(self, scopes=None): - """ - Checks if the access token is valid. - - :param scopes: An iterable containing the scopes to check or None - """ - return not self.is_expired() and self.allow_scopes(scopes) - - def is_expired(self): - """ - Check token expiration with timezone awareness - """ - if not self.expires: - return True - - return timezone.now() >= self.expires - - def allow_scopes(self, scopes): - """ - Check if the token allows the provided scopes - - :param scopes: An iterable containing the scopes to check - """ - if not scopes: - return True - - provided_scopes = set(self.scope.split()) - resource_scopes = set(scopes) - - return resource_scopes.issubset(provided_scopes) - - def revoke(self): - """ - Convenience method to uniform tokens' interface, for now - simply remove this token from the database in order to revoke it. - """ - self.delete() - - @property - def scopes(self): - """ - Returns a dictionary of allowed scope names (as keys) with their descriptions (as values) - """ - all_scopes = get_scopes_backend().get_all_scopes() - token_scopes = self.scope.split() - return {name: desc for name, desc in all_scopes.items() if name in token_scopes} - - @property - def claims(self): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - jwt_token = jwt.JWT(key=key, jwt=self.token) - return json.loads(jwt_token.claims) - - def __str__(self): - return self.token - - class Meta: - abstract = True - - -class IDToken(AbstractIDToken): - class Meta(AbstractIDToken.Meta): - swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" - - def get_application_model(): """ Return the Application model that is active in this project. """ return apps.get_model(oauth2_settings.APPLICATION_MODEL) @@ -539,11 +430,6 @@ def get_access_token_model(): return apps.get_model(oauth2_settings.ACCESS_TOKEN_MODEL) -def get_id_token_model(): - """ Return the AccessToken model that is active in this project. """ - return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) - - def get_refresh_token_model(): """ Return the RefreshToken model that is active in this project. """ return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 404add70e..6d8e68a2c 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -104,7 +104,7 @@ def validate_authorization_request(self, request): except oauth2.OAuth2Error as error: raise OAuthToolkitError(error=error) - def create_authorization_response(self, uri, request, scopes, credentials, body, allow): + def create_authorization_response(self, request, scopes, credentials, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -112,8 +112,7 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, :param request: The current django.http.HttpRequest object :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri` and `response_type` - :param body: Other body parameters not used in credentials dictionary + `client_id`, `state`, `redirect_uri`, `response_type` :param allow: True if the user authorize the client, otherwise False """ try: @@ -125,10 +124,10 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, credentials["user"] = request.user headers, body, status = self.server.create_authorization_response( - uri=uri, scopes=scopes, credentials=credentials, body=body) - redirect_uri = headers.get("Location", None) + uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials) + uri = headers.get("Location", None) - return redirect_uri, headers, body, status + return uri, headers, body, status except oauth2.FatalClientError as error: raise FatalClientError( @@ -167,21 +166,6 @@ def create_revocation_response(self, request): return uri, headers, body, status - def create_userinfo_response(self, request): - """ - A wrapper method that calls create_userinfo_response on a - `server_class` instance. - - :param request: The current django.http.HttpRequest object - """ - uri, http_method, body, headers = self._extract_params(request) - headers, body, status = self.server.create_userinfo_response( - uri, http_method, body, headers - ) - uri = headers.get("Location", None) - - return uri, headers, body, status - def verify_request(self, request, scopes): """ A wrapper method that calls verify_request on `server_class` instance. diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index e7fb860b3..515353d6f 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,8 +1,6 @@ import base64 import binascii -import hashlib import http.client -import json import logging from collections import OrderedDict from datetime import datetime, timedelta @@ -14,21 +12,15 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q -from django.http import HttpRequest -from django.urls import reverse -from django.utils import dateformat, timezone +from django.utils import timezone from django.utils.timezone import make_aware from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt -from jwcrypto.common import JWException -from jwcrypto.jwt import JWTExpired from oauthlib.oauth2 import RequestValidator -from oauthlib.oauth2.rfc6749 import utils from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model + AbstractApplication, get_access_token_model, + get_application_model, get_grant_model, get_refresh_token_model ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -37,23 +29,18 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": ( - AbstractApplication.GRANT_AUTHORIZATION_CODE, - AbstractApplication.GRANT_OPENID_HYBRID, - ), - "password": (AbstractApplication.GRANT_PASSWORD,), - "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS,), + "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, ), + "password": (AbstractApplication.GRANT_PASSWORD, ), + "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS, ), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, - AbstractApplication.GRANT_OPENID_HYBRID, - ), + ) } Application = get_application_model() AccessToken = get_access_token_model() -IDToken = get_id_token_model() Grant = get_grant_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() @@ -106,15 +93,12 @@ def _authenticate_basic_auth(self, request): except UnicodeDecodeError: log.debug( "Failed basic auth: %r can't be decoded as unicode by %r", - auth_string, - encoding, + auth_string, encoding ) return False try: - client_id, client_secret = map( - unquote_plus, auth_string_decoded.split(":", 1) - ) + client_id, client_secret = map(unquote_plus, auth_string_decoded.split(":", 1)) except ValueError: log.debug("Failed basic auth, Invalid base64 encoding.") return False @@ -163,54 +147,35 @@ def _load_application(self, client_id, request): """ # we want to be sure that request has the client attribute! - assert hasattr( - request, "client" - ), '"request" instance has no "client" attribute' + assert hasattr(request, "client"), '"request" instance has no "client" attribute' try: - request.client = request.client or Application.objects.get( - client_id=client_id - ) + request.client = request.client or Application.objects.get(client_id=client_id) # Check that the application can be used (defaults to always True) if not request.client.is_usable(request): - log.debug( - "Failed body authentication: Application %r is disabled" - % (client_id) - ) + log.debug("Failed body authentication: Application %r is disabled" % (client_id)) return None return request.client except Application.DoesNotExist: - log.debug( - "Failed body authentication: Application %r does not exist" - % (client_id) - ) + log.debug("Failed body authentication: Application %r does not exist" % (client_id)) return None def _set_oauth2_error_on_request(self, request, access_token, scopes): if access_token is None: - error = OrderedDict( - [ - ("error", "invalid_token",), - ("error_description", _("The access token is invalid."),), - ] - ) + error = OrderedDict([ + ("error", "invalid_token", ), + ("error_description", _("The access token is invalid."), ), + ]) elif access_token.is_expired(): - error = OrderedDict( - [ - ("error", "invalid_token",), - ("error_description", _("The access token has expired."),), - ] - ) + error = OrderedDict([ + ("error", "invalid_token", ), + ("error_description", _("The access token has expired."), ), + ]) elif not access_token.allow_scopes(scopes): - error = OrderedDict( - [ - ("error", "insufficient_scope",), - ( - "error_description", - _("The access token is valid but does not have enough scope."), - ), - ] - ) + error = OrderedDict([ + ("error", "insufficient_scope", ), + ("error_description", _("The access token is valid but does not have enough scope."), ), + ]) else: log.warning("OAuth2 access token is invalid for an unknown reason.") error = OrderedDict([ @@ -276,15 +241,11 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs): proceed only if the client exists and is not of type "Confidential". """ if self._load_application(client_id, request) is not None: - log.debug( - "Application %r has type %r" % (client_id, request.client.client_type) - ) + log.debug("Application %r has type %r" % (client_id, request.client.client_type)) return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL return False - def confirm_redirect_uri( - self, client_id, code, redirect_uri, client, *args, **kwargs - ): + def confirm_redirect_uri(self, client_id, code, redirect_uri, client, *args, **kwargs): """ Ensure the redirect_uri is listed in the Application instance redirect_uris field """ @@ -309,7 +270,7 @@ def get_default_redirect_uri(self, client_id, request, *args, **kwargs): return request.client.default_redirect_uri def _get_token_from_authentication_server( - self, token, introspection_url, introspection_token, introspection_credentials + self, token, introspection_url, introspection_token, introspection_credentials ): """Use external introspection endpoint to "crack open" the token. :param introspection_url: introspection endpoint URL @@ -337,12 +298,11 @@ def _get_token_from_authentication_server( try: response = requests.post( - introspection_url, data={"token": token}, headers=headers + introspection_url, + data={"token": token}, headers=headers ) except requests.exceptions.RequestException: - log.exception( - "Introspection: Failed POST to %r in token lookup", introspection_url - ) + log.exception("Introspection: Failed POST to %r in token lookup", introspection_url) return None # Log an exception when response from auth server is not successful @@ -388,8 +348,7 @@ def _get_token_from_authentication_server( "application": None, "scope": scope, "expires": expires, - }, - ) + }) return access_token @@ -402,14 +361,10 @@ def validate_bearer_token(self, token, scopes, request): introspection_url = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN - introspection_credentials = ( - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS - ) + introspection_credentials = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS try: - access_token = AccessToken.objects.select_related( - "application", "user" - ).get(token=token) + access_token = AccessToken.objects.select_related("application", "user").get(token=token) except AccessToken.DoesNotExist: access_token = None @@ -420,7 +375,7 @@ def validate_bearer_token(self, token, scopes, request): token, introspection_url, introspection_token, - introspection_credentials, + introspection_credentials ) if access_token and access_token.is_valid(scopes): @@ -447,38 +402,22 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs): except Grant.DoesNotExist: return False - def validate_grant_type( - self, client_id, grant_type, client, request, *args, **kwargs - ): + def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs): """ Validate both grant_type is a valid string and grant_type is allowed for current workflow """ - assert grant_type in GRANT_TYPE_MAPPING # mapping misconfiguration + assert(grant_type in GRANT_TYPE_MAPPING) # mapping misconfiguration return request.client.allows_grant_type(*GRANT_TYPE_MAPPING[grant_type]) - def validate_response_type( - self, client_id, response_type, client, request, *args, **kwargs - ): + def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): """ We currently do not support the Authorization Endpoint Response Types registry as in rfc:`8.4`, so validate the response_type only if it matches "code" or "token" """ if response_type == "code": - return client.allows_grant_type( - AbstractApplication.GRANT_AUTHORIZATION_CODE - ) + return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "id_token": - return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "id_token token": - return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "code id_token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) - elif response_type == "code token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) - elif response_type == "code id_token token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) else: return False @@ -486,15 +425,11 @@ def validate_scopes(self, client_id, scopes, client, request, *args, **kwargs): """ Ensure required scopes are permitted (as specified in the settings file) """ - available_scopes = get_scopes_backend().get_available_scopes( - application=client, request=request - ) + available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request) return set(scopes).issubset(set(available_scopes)) def get_default_scopes(self, client_id, request, *args, **kwargs): - default_scopes = get_scopes_backend().get_default_scopes( - application=request.client, request=request - ) + default_scopes = get_scopes_backend().get_default_scopes(application=request.client, request=request) return default_scopes def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs): @@ -522,24 +457,6 @@ def get_code_challenge_method(self, code, request): def save_authorization_code(self, client_id, code, request, *args, **kwargs): self._create_authorization_code(request, code) - def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): - scopes = [] - fields = { - "code": code, - } - - if client_id: - fields["application__client_id"] = client_id - - if redirect_uri: - fields["redirect_uri"] = redirect_uri - - grant = Grant.objects.filter(**fields).values() - if grant.exists(): - grant_dict = dict(grant[0]) - scopes = utils.scope_to_list(grant_dict["scope"]) - return scopes - def rotate_refresh_token(self, request): """ Checks if rotate refresh token is enabled @@ -580,11 +497,9 @@ def save_bearer_token(self, token, request, *args, **kwargs): refresh_token_instance = getattr(request, "refresh_token_instance", None) # If we are to reuse tokens, and we can: do so - if ( - not self.rotate_refresh_token(request) - and isinstance(refresh_token_instance, RefreshToken) - and refresh_token_instance.access_token - ): + if not self.rotate_refresh_token(request) and \ + isinstance(refresh_token_instance, RefreshToken) and \ + refresh_token_instance.access_token: access_token = AccessToken.objects.select_for_update().get( pk=refresh_token_instance.access_token.pk @@ -631,18 +546,14 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - self._create_refresh_token( - request, refresh_token_code, access_token - ) + self._create_refresh_token(request, refresh_token_code, access_token) else: # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token - token["refresh_token"] = ( - RefreshToken.objects.filter(access_token=previous_access_token) - .first() - .token - ) + token["refresh_token"] = RefreshToken.objects.filter( + access_token=previous_access_token + ).first().token token["scope"] = previous_access_token.scope # No refresh token should be created, just access token @@ -650,15 +561,11 @@ def save_bearer_token(self, token, request, *args, **kwargs): self._create_access_token(expires, request, token) def _create_access_token(self, expires, request, token, source_refresh_token=None): - id_token = token.get("id_token", None) - if id_token: - id_token = IDToken.objects.get(token=id_token) return AccessToken.objects.create( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], - id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) @@ -683,7 +590,7 @@ def _create_refresh_token(self, request, refresh_token_code, access_token): user=request.user, token=refresh_token_code, application=request.client, - access_token=access_token, + access_token=access_token ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -736,8 +643,9 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) + revoked__gt=timezone.now() - timedelta( + seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS + ) ) rt = RefreshToken.objects.filter(null_or_recent, token=refresh_token).select_related( "access_token" @@ -751,183 +659,3 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt return rt.application == client - - @transaction.atomic - def _save_id_token(self, token, request, expires, *args, **kwargs): - - scopes = request.scope or " ".join(request.scopes) - - if request.grant_type == "client_credentials": - request.user = None - - id_token = IDToken.objects.create( - user=request.user, - scope=scopes, - expires=expires, - token=token.serialize(), - application=request.client, - ) - return id_token - - def get_jwt_bearer_token(self, token, token_handler, request): - return self.get_id_token(token, token_handler, request) - - def get_oidc_claims(self, token, token_handler, request): - # Required OIDC claims - claims = { - "sub": str(request.user.id), - } - - # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims - claims.update(**self.get_additional_claims(request)) - - return claims - - def get_id_token_dictionary(self, token, token_handler, request): - # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 - # Save the id_token on database bound to code when the request come to - # Authorization Endpoint and return the same one when request come to - # Token Endpoint - - # TODO: Check if at this point this request parameters are alredy validated - claims = self.get_oidc_claims(token, token_handler, request) - - expiration_time = timezone.now() + timedelta( - seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS - ) - # Required ID Token claims - claims.update(**{ - "iss": self.get_oidc_issuer_endpoint(request), - "aud": request.client_id, - "exp": int(dateformat.format(expiration_time, "U")), - "iat": int(dateformat.format(datetime.utcnow(), "U")), - "auth_time": int(dateformat.format(request.user.last_login, "U")), - }) - - nonce = getattr(request, "nonce", None) - if nonce: - claims["nonce"] = nonce - - # TODO: create a function to check if we should add at_hash - # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken - # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken - # if request.grant_type in 'authorization_code' and 'access_token' in token: - if ( - (request.grant_type == "authorization_code" and "access_token" in token) - or request.response_type == "code id_token token" - or (request.response_type == "id_token token" and "access_token" in token) - ): - acess_token = token["access_token"] - at_hash = self.generate_at_hash(acess_token) - claims["at_hash"] = at_hash - - # TODO: create a function to check if we should include c_hash - # http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken - if request.response_type in ("code id_token", "code id_token token"): - code = token["code"] - sha256 = hashlib.sha256(code.encode("ascii")) - bits256 = sha256.hexdigest()[:32] - c_hash = base64.urlsafe_b64encode(bits256.encode("ascii")) - claims["c_hash"] = c_hash.decode("utf8") - - return claims, expiration_time - - def get_oidc_issuer_endpoint(self, request): - if oauth2_settings.OIDC_ISS_ENDPOINT: - return oauth2_settings.OIDC_ISS_ENDPOINT - - # generate it based on known URL - django_request = HttpRequest() - django_request.META = request.headers - - abs_url = django_request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) - base_url = abs_url[:-len("/.well-known/openid-configuration/")] - return base_url - - def generate_at_hash(self, access_token): - sha256 = hashlib.sha256(access_token.encode("ascii")) - bits128 = sha256.digest()[:16] - at_hash = base64.urlsafe_b64encode(bits128).decode("utf8").rstrip("=") - return at_hash - - def get_id_token(self, token, token_handler, request): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - - claims, expiration_time = self.get_id_token_dictionary(token, token_handler, request) - - jwt_token = jwt.JWT( - header=json.dumps({"alg": "RS256"}, default=str), - claims=json.dumps(claims, default=str), - ) - jwt_token.make_signed_token(key) - - id_token = self._save_id_token(jwt_token, request, expiration_time) - # this is needed by django rest framework - request.access_token = id_token - request.id_token = id_token - return jwt_token.serialize() - - def validate_jwt_bearer_token(self, token, scopes, request): - return self.validate_id_token(token, scopes, request) - - def validate_id_token(self, token, scopes, request): - """ - When users try to access resources, check that provided id_token is valid - """ - if not token: - return False - - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - - try: - jwt_token = jwt.JWT(key=key, jwt=token) - id_token = IDToken.objects.get(token=jwt_token.serialize()) - request.client = id_token.application - request.user = id_token.user - request.scopes = scopes - # this is needed by django rest framework - request.access_token = id_token - return True - except (JWException, JWTExpired): - # TODO: This is the base exception of all jwcrypto - return False - - return False - - def validate_user_match(self, id_token_hint, scopes, claims, request): - # TODO: Fix to validate when necessary acording - # https://github.com/idan/oauthlib/blob/master/oauthlib/oauth2/rfc6749/request_validator.py#L556 - # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest id_token_hint section - return True - - def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): - """ Extracts nonce from saved authorization code. - If present in the Authentication Request, Authorization - Servers MUST include a nonce Claim in the ID Token with the - Claim Value being the nonce value sent in the Authentication - Request. Authorization Servers SHOULD perform no other - processing on nonce values used. The nonce value is a - case-sensitive string. - Only code param should be sufficient to retrieve grant code from - any storage you are using. However, `client_id` and `redirect_uri` - have been validated and can be used also. - :param client_id: Unicode client identifier - :param code: Unicode authorization code grant - :param redirect_uri: Unicode absolute URI - :return: Unicode nonce - Method is used by: - - Authorization Token Grant Dispatcher - """ - # TODO: Fix this ;) - return "" - - def get_userinfo_claims(self, request): - """ - Generates and saves a new JWT for this request, and returns it as the - current user's claims. - - """ - return self.get_oidc_claims(None, None, request) - - def get_additional_claims(self, request): - return {} diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index d3d60801e..0135da8b7 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -23,19 +23,10 @@ USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) -APPLICATION_MODEL = getattr( - settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application" -) -ACCESS_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken" -) -ID_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken" -) +APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") +ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") -REFRESH_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken" -) +REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") DEFAULTS = { "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", @@ -44,7 +35,7 @@ "ACCESS_TOKEN_GENERATOR": None, "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, - "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -54,46 +45,29 @@ "WRITE_SCOPE": "write", "AUTHORIZATION_CODE_EXPIRE_SECONDS": 60, "ACCESS_TOKEN_EXPIRE_SECONDS": 36000, - "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, "ACCESS_TOKEN_MODEL": ACCESS_TOKEN_MODEL, - "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], - "OIDC_ISS_ENDPOINT": "", - "OIDC_USERINFO_ENDPOINT": "", - "OIDC_RSA_PRIVATE_KEY": "", - "OIDC_RESPONSE_TYPES_SUPPORTED": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token", - ], - "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED": ["RS256", "HS256"], - "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": [ - "client_secret_post", - "client_secret_basic", - ], + # Special settings that will be evaluated at runtime "_SCOPES": [], "_DEFAULT_SCOPES": [], + # Resource Server with Token Introspection "RESOURCE_SERVER_INTROSPECTION_URL": None, "RESOURCE_SERVER_AUTH_TOKEN": None, "RESOURCE_SERVER_INTROSPECTION_CREDENTIALS": None, "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, + # Whether or not PKCE is required - "PKCE_REQUIRED": False, + "PKCE_REQUIRED": False } # List of settings that cannot be empty @@ -105,11 +79,6 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", - "OIDC_RSA_PRIVATE_KEY", - "OIDC_RESPONSE_TYPES_SUPPORTED", - "OIDC_SUBJECT_TYPES_SUPPORTED", - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", - "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED", ) # List of settings that may be in string import notation. @@ -148,12 +117,7 @@ def import_from_string(val, setting_name): module = importlib.import_module(module_path) return getattr(module, class_name) except ImportError as e: - msg = "Could not import %r for setting %r. %s: %s." % ( - val, - setting_name, - e.__class__.__name__, - e, - ) + msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) @@ -165,9 +129,7 @@ class OAuth2ProviderSettings: and return the class, rather than the string literal. """ - def __init__( - self, user_settings=None, defaults=None, import_strings=None, mandatory=None - ): + def __init__(self, user_settings=None, defaults=None, import_strings=None, mandatory=None): self.user_settings = user_settings or {} self.defaults = defaults or {} self.import_strings = import_strings or () @@ -202,9 +164,7 @@ def __getattr__(self, attr): if scope in self._SCOPES: val.append(scope) else: - raise ImproperlyConfigured( - "Defined DEFAULT_SCOPES not present in SCOPES" - ) + raise ImproperlyConfigured("Defined DEFAULT_SCOPES not present in SCOPES") self.validate_setting(attr, val) diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index f2f04d853..4cf6d4c6d 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -27,12 +27,5 @@ name="authorized-token-delete"), ] -oidc_urlpatterns = [ - re_path(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), - name="oidc-connect-discovery-info"), - re_path(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), - re_path(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") -] - -urlpatterns = base_urlpatterns + management_urlpatterns + oidc_urlpatterns +urlpatterns = base_urlpatterns + management_urlpatterns diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 9f2ac4ff7..7636bd9c7 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -1,13 +1,9 @@ # flake8: noqa -from .application import ( - ApplicationDelete, ApplicationDetail, ApplicationList, - ApplicationRegistration, ApplicationUpdate -) -from .base import AuthorizationView, RevokeTokenView, TokenView +from .base import AuthorizationView, TokenView, RevokeTokenView +from .application import ApplicationRegistration, ApplicationDetail, ApplicationList, \ + ApplicationDelete, ApplicationUpdate from .generic import ( - ProtectedResourceView, ReadWriteScopedResourceView, - ScopedProtectedResourceView -) + ProtectedResourceView, ScopedProtectedResourceView, ReadWriteScopedResourceView, + ClientProtectedResourceView, ClientProtectedScopedResourceView) +from .token import AuthorizedTokensListView, AuthorizedTokenDeleteView from .introspect import IntrospectTokenView -from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView -from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index b38c907ab..c925493f5 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -32,7 +32,7 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris", "algorithm", + "authorization_grant_type", "redirect_uris" ) ) @@ -81,6 +81,6 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris", "algorithm", + "authorization_grant_type", "redirect_uris" ) ) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index eb825c307..b9b6ed7f9 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -86,7 +86,6 @@ class AuthorizationView(BaseAuthorizationView, FormView): * Authorization code * Implicit grant """ - template_name = "oauth2_provider/authorize.html" form_class = AllowForm @@ -102,14 +101,11 @@ def get_initial(self): initial_data = { "redirect_uri": self.oauth2_data.get("redirect_uri", None), "scope": " ".join(scopes), - "nonce": self.oauth2_data.get("nonce", None), "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), "code_challenge": self.oauth2_data.get("code_challenge", None), - "code_challenge_method": self.oauth2_data.get( - "code_challenge_method", None - ), + "code_challenge_method": self.oauth2_data.get("code_challenge_method", None), } return initial_data @@ -120,27 +116,18 @@ def form_valid(self, form): "client_id": form.cleaned_data.get("client_id"), "redirect_uri": form.cleaned_data.get("redirect_uri"), "response_type": form.cleaned_data.get("response_type", None), - "state": form.cleaned_data.get("state", None), + "state": form.cleaned_data.get("state", None) } if form.cleaned_data.get("code_challenge", False): credentials["code_challenge"] = form.cleaned_data.get("code_challenge") if form.cleaned_data.get("code_challenge_method", False): - credentials["code_challenge_method"] = form.cleaned_data.get( - "code_challenge_method" - ) - - body = {"nonce": form.cleaned_data.get("nonce")} + credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") try: uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=scopes, - credentials=credentials, - body=body, - allow=allow, + request=self.request, scopes=scopes, credentials=credentials, allow=allow ) except OAuthToolkitError as error: return self.error_response(error, application) @@ -162,21 +149,13 @@ def get(self, request, *args, **kwargs): # at this point we know an Application instance with such client_id exists in the database # TODO: Cache this! - application = get_application_model().objects.get( - client_id=credentials["client_id"] - ) - - uri_query = urllib.parse.urlparse(self.request.get_raw_uri()).query - uri_query_params = dict( - urllib.parse.parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True) - ) + application = get_application_model().objects.get(client_id=credentials["client_id"]) kwargs["application"] = application kwargs["client_id"] = credentials["client_id"] kwargs["redirect_uri"] = credentials["redirect_uri"] kwargs["response_type"] = credentials["response_type"] kwargs["state"] = credentials["state"] - kwargs["nonce"] = uri_query_params.get("nonce", None) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 @@ -185,9 +164,7 @@ def get(self, request, *args, **kwargs): # Check to see if the user has already granted access and return # a successful response depending on "approval_prompt" url parameter - require_approval = request.GET.get( - "approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT - ) + require_approval = request.GET.get("approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT) try: # If skip_authorization field is True, skip the authorization screen even @@ -196,36 +173,26 @@ def get(self, request, *args, **kwargs): # are already approved. if application.skip_authorization: uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=" ".join(scopes), - credentials=credentials, - allow=True, + request=self.request, scopes=" ".join(scopes), + credentials=credentials, allow=True ) return self.redirect(uri, application) elif require_approval == "auto": - tokens = ( - get_access_token_model() - .objects.filter( - user=request.user, - application=kwargs["application"], - expires__gt=timezone.now(), - ) - .all() - ) + tokens = get_access_token_model().objects.filter( + user=request.user, + application=kwargs["application"], + expires__gt=timezone.now() + ).all() # check past authorizations regarded the same scopes as the current one for token in tokens: if token.allow_scopes(scopes): uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=" ".join(scopes), - credentials=credentials, - allow=True, + request=self.request, scopes=" ".join(scopes), + credentials=credentials, allow=True ) - return self.redirect(uri, application) + return self.redirect(uri, application, token) except OAuthToolkitError as error: return self.error_response(error, application) @@ -272,7 +239,6 @@ class TokenView(OAuthLibMixin, View): * Password * Client credentials """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS @@ -283,8 +249,11 @@ def post(self, request, *args, **kwargs): if status == 200: access_token = json.loads(body).get("access_token") if access_token is not None: - token = get_access_token_model().objects.get(token=access_token) - app_authorized.send(sender=self, request=request, token=token) + token = get_access_token_model().objects.get( + token=access_token) + app_authorized.send( + sender=self, request=request, + token=token) response = HttpResponse(content=body, status=status) for k, v in headers.items(): @@ -297,7 +266,6 @@ class RevokeTokenView(OAuthLibMixin, View): """ Implements an endpoint to revoke access or refresh tokens """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS diff --git a/oauth2_provider/views/introspect.py b/oauth2_provider/views/introspect.py index 460a1395d..7d4381179 100644 --- a/oauth2_provider/views/introspect.py +++ b/oauth2_provider/views/introspect.py @@ -7,7 +7,7 @@ from django.views.decorators.csrf import csrf_exempt from oauth2_provider.models import get_access_token_model -from oauth2_provider.views.generic import ClientProtectedScopedResourceView +from oauth2_provider.views import ClientProtectedScopedResourceView @method_decorator(csrf_exempt, name="dispatch") diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 0b7e02c7a..b5d0d4145 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -97,7 +97,7 @@ def validate_authorization_request(self, request): core = self.get_oauthlib_core() return core.validate_authorization_request(request) - def create_authorization_response(self, uri, request, scopes, credentials, allow, body=None): + def create_authorization_response(self, request, scopes, credentials, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -105,15 +105,14 @@ def create_authorization_response(self, uri, request, scopes, credentials, allow :param request: The current django.http.HttpRequest object :param scopes: A space-separated string of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri` and `response_type` + `client_id`, `state`, `redirect_uri`, `response_type` :param allow: True if the user authorize the client, otherwise False - :param body: Other body parameters not used in credentials dictionary """ # TODO: move this scopes conversion from and to string into a utils function scopes = scopes.split(" ") if scopes else [] core = self.get_oauthlib_core() - return core.create_authorization_response(uri, request, scopes, credentials, body, allow) + return core.create_authorization_response(request, scopes, credentials, allow) def create_token_response(self, request): """ @@ -134,16 +133,6 @@ def create_revocation_response(self, request): core = self.get_oauthlib_core() return core.create_revocation_response(request) - def create_userinfo_response(self, request): - """ - A wrapper method that calls create_userinfo_response on the - `server_class` instance. - - :param request: The current django.http.HttpRequest object - """ - core = self.get_oauthlib_core() - return core.create_userinfo_response(request) - def verify_request(self, request): """ A wrapper method that calls verify_request on `server_class` instance. @@ -288,13 +277,11 @@ def dispatch(self, request, *args, **kwargs): if not valid: # Alternatively allow access tokens # check if the request is valid and the protected resource may be accessed - try: - valid, r = self.verify_request(request) - if valid: - request.resource_owner = r.user - return super().dispatch(request, *args, **kwargs) - except ValueError: - pass - return HttpResponseForbidden() + valid, r = self.verify_request(request) + if valid: + request.resource_owner = r.user + return super().dispatch(request, *args, **kwargs) + else: + return HttpResponseForbidden() else: return super().dispatch(request, *args, **kwargs) diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py deleted file mode 100644 index d7ffe4670..000000000 --- a/oauth2_provider/views/oidc.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import json - -from django.http import HttpResponse, JsonResponse -from django.urls import reverse, reverse_lazy -from django.utils.decorators import method_decorator -from django.views.decorators.csrf import csrf_exempt -from django.views.generic import View -from jwcrypto import jwk - -from ..settings import oauth2_settings -from .mixins import OAuthLibMixin - - -class ConnectDiscoveryInfoView(View): - """ - View used to show oidc provider configuration information - """ - def get(self, request, *args, **kwargs): - issuer_url = oauth2_settings.OIDC_ISS_ENDPOINT - - if not issuer_url: - abs_url = request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) - issuer_url = abs_url[:-len("/.well-known/openid-configuration/")] - - authorization_endpoint = request.build_absolute_uri(reverse("oauth2_provider:authorize")) - token_endpoint = request.build_absolute_uri(reverse("oauth2_provider:token")) - userinfo_endpoint = ( - oauth2_settings.OIDC_USERINFO_ENDPOINT or - request.build_absolute_uri(reverse("oauth2_provider:user-info")) - ) - jwks_uri = request.build_absolute_uri(reverse("oauth2_provider:jwks-info")) - else: - authorization_endpoint = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:authorize")) - token_endpoint = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:token")) - userinfo_endpoint = ( - oauth2_settings.OIDC_USERINFO_ENDPOINT or - "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:user-info")) - ) - jwks_uri = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:jwks-info")) - - data = { - "issuer": issuer_url, - "authorization_endpoint": authorization_endpoint, - "token_endpoint": token_endpoint, - "userinfo_endpoint": userinfo_endpoint, - "jwks_uri": jwks_uri, - "response_types_supported": oauth2_settings.OIDC_RESPONSE_TYPES_SUPPORTED, - "subject_types_supported": oauth2_settings.OIDC_SUBJECT_TYPES_SUPPORTED, - "id_token_signing_alg_values_supported": - oauth2_settings.OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, - "token_endpoint_auth_methods_supported": - oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED, - } - response = JsonResponse(data) - response["Access-Control-Allow-Origin"] = "*" - return response - - -class JwksInfoView(View): - """ - View used to show oidc json web key set document - """ - def get(self, request, *args, **kwargs): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - data = { - "keys": [{ - "alg": "RS256", - "use": "sig", - "kid": key.thumbprint() - }] - } - data["keys"][0].update(json.loads(key.export_public())) - response = JsonResponse(data) - response["Access-Control-Allow-Origin"] = "*" - return response - - -@method_decorator(csrf_exempt, name="dispatch") -class UserInfoView(OAuthLibMixin, View): - """ - View used to show Claims about the authenticated End-User - """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - - def get(self, request, *args, **kwargs): - url, headers, body, status = self.create_userinfo_response(request) - response = HttpResponse(content=body or "", status=status) - - for k, v in headers.items(): - response[k] = v - return response diff --git a/setup.cfg b/setup.cfg index fb060f88e..3c4e0badc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,6 @@ install_requires = django >= 2.1 requests >= 2.13.0 oauthlib >= 3.1.0 - jwcrypto >= 0.4.2 [options.packages.find] exclude = tests diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index eef6dbab5..60b17f2ae 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -45,7 +45,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -53,7 +53,6 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('custom_field', models.CharField(max_length=255)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleapplication', to=settings.AUTH_USER_MODEL)), - ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, @@ -72,7 +71,6 @@ class Migration(migrations.Migration): ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleaccesstoken', to=settings.AUTH_USER_MODEL)), - ('id_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL)), ], options={ 'abstract': False, @@ -85,7 +83,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -93,7 +91,6 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('allowed_schemes', models.TextField(blank=True)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_basetestapplication', to=settings.AUTH_USER_MODEL)), - ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, diff --git a/tests/settings.py b/tests/settings.py index edd1ae679..40eef5ebd 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -130,30 +130,3 @@ }, } } - -OIDC_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT -j0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP -0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB -AoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77 -+IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju -YBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn -2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq -MH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el -fVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc -uEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67 -ZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT -qoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr -dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY ------END RSA PRIVATE KEY-----""" - -OAUTH2_PROVIDER = { - "OIDC_ISS_ENDPOINT": "http://localhost", - "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", - "OIDC_RSA_PRIVATE_KEY": OIDC_RSA_PRIVATE_KEY, -} - -OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" -OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" -OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" -OAUTH2_PROVIDER_ID_TOKEN_MODEL = "oauth2_provider.IDToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 64e112da3..6130876ce 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -50,7 +50,6 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, - "algorithm": "RS256", } response = self.client.post(reverse("oauth2_provider:register"), form_data) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index e4eb8ae81..e98f5b041 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -41,12 +41,8 @@ def get(self, request, *args, **kwargs): class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() - self.test_user = UserModel.objects.create_user( - "test_user", "test@example.com", "123456" - ) - self.dev_user = UserModel.objects.create_user( - "dev_user", "dev@example.com", "123456" - ) + self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] @@ -61,13 +57,8 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._SCOPES = ["read", "write"] oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } def tearDown(self): self.application.delete() @@ -112,25 +103,6 @@ def test_skip_authorization_completely(self): }) self.assertEqual(response.status_code, 302) - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - def test_pre_auth_invalid_client(self): """ Test error for an invalid client_id with response_type: code @@ -175,32 +147,6 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="test_user", password="123456") - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -230,11 +176,10 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): def test_pre_auth_approval_prompt(self): tok = AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") @@ -259,11 +204,10 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") query_data = { @@ -280,11 +224,10 @@ def test_pre_auth_approval_prompt_default_override(self): oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") query_data = { @@ -359,32 +302,7 @@ def test_code_post_auth_allow(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -405,9 +323,7 @@ def test_code_post_auth_deny(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -426,9 +342,7 @@ def test_code_post_auth_deny_no_state(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertNotIn("state", response["Location"]) @@ -448,9 +362,7 @@ def test_code_post_auth_bad_responsetype(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?error", response["Location"]) @@ -469,9 +381,7 @@ def test_code_post_auth_forbidden_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_malicious_redirect_uri(self): @@ -489,9 +399,7 @@ def test_code_post_auth_malicious_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_allow_custom_redirect_uri_scheme(self): @@ -510,9 +418,7 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -534,9 +440,7 @@ def test_code_post_auth_deny_custom_redirect_uri_scheme(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -559,9 +463,7 @@ def test_code_post_auth_redirection_uri_with_querystring(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?foo=bar", response["Location"]) self.assertIn("code=", response["Location"]) @@ -584,9 +486,7 @@ def test_code_post_auth_failing_redirection_uri_with_querystring(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -608,29 +508,25 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) class TestAuthorizationCodeTokenView(BaseTest): - def get_auth(self, scope="read write"): + def get_auth(self): """ Helper method to retrieve a valid authorization code """ authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", - "scope": scope, + "scope": "read write", "redirect_uri": "http://example.org", "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) return query_dict["code"].pop() @@ -640,13 +536,9 @@ def generate_pkce_codes(self, algorithm, length=43): """ code_verifier = get_random_string(length) if algorithm == "S256": - code_challenge = ( - base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ) - .decode() - .rstrip("=") - ) + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") else: code_challenge = code_verifier return code_verifier, code_challenge @@ -667,9 +559,7 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): "code_challenge_method": code_challenge_method, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() @@ -684,23 +574,17 @@ def test_basic_auth(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_refresh(self): """ @@ -712,15 +596,11 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -729,29 +609,23 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) # check refresh token cannot be used twice - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) content = json.loads(response.content.decode("utf-8")) self.assertTrue("invalid_grant" in content.values()) @@ -767,15 +641,11 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -784,11 +654,9 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", @@ -796,9 +664,7 @@ def test_refresh_with_grace_period(self): "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -807,9 +673,7 @@ def test_refresh_with_grace_period(self): first_refresh_token = content["refresh_token"] # check access token returns same data if used twice, see #497 - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) @@ -829,15 +693,11 @@ def test_refresh_invalidates_old_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) rt = content["refresh_token"] @@ -848,9 +708,7 @@ def test_refresh_invalidates_old_tokens(self): "refresh_token": rt, "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) refresh_token = RefreshToken.objects.filter(token=rt).first() @@ -867,15 +725,11 @@ def test_refresh_no_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -883,9 +737,7 @@ def test_refresh_no_scopes(self): "grant_type": "refresh_token", "refresh_token": content["refresh_token"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -901,15 +753,11 @@ def test_refresh_bad_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -918,9 +766,7 @@ def test_refresh_bad_scopes(self): "refresh_token": content["refresh_token"], "scope": "read write nuke", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_fail_repeating_requests(self): @@ -933,15 +779,11 @@ def test_refresh_fail_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -950,13 +792,9 @@ def test_refresh_fail_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_repeating_requests(self): @@ -971,15 +809,11 @@ def test_refresh_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -988,26 +822,18 @@ def test_refresh_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) # try refreshing outside the refresh window, see #497 rt = RefreshToken.objects.get(token=content["refresh_token"]) self.assertIsNotNone(rt.revoked) - rt.revoked = timezone.now() - datetime.timedelta( - minutes=10 - ) # instead of mocking out datetime + rt.revoked = timezone.now() - datetime.timedelta(minutes=10) # instead of mocking out datetime rt.save() - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 @@ -1021,15 +847,11 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -1040,13 +862,9 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): } oauth2_settings.ROTATE_REFRESH_TOKEN = False - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) oauth2_settings.ROTATE_REFRESH_TOKEN = True @@ -1060,15 +878,11 @@ def test_basic_auth_bad_authcode(self): token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): @@ -1080,15 +894,11 @@ def test_basic_auth_bad_granttype(self): token_request_data = { "grant_type": "UNKNOWN", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_grant_expired(self): @@ -1097,27 +907,18 @@ def test_basic_auth_grant_expired(self): """ self.client.login(username="test_user", password="123456") g = Grant( - application=self.application, - user=self.test_user, - code="BLAH", - expires=timezone.now(), - redirect_uri="", - scope="", - ) + application=self.application, user=self.test_user, code="BLAH", + expires=timezone.now(), redirect_uri="", scope="") g.save() token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): @@ -1130,13 +931,11 @@ def test_basic_auth_bad_secret(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_basic_auth_wrong_auth_type(self): @@ -1149,20 +948,16 @@ def test_basic_auth_wrong_auth_type(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - user_pass = "{0}:{1}".format( - self.application.client_id, self.application.client_secret - ) + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) auth_string = base64.b64encode(user_pass.encode("utf-8")) auth_headers = { "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_request_body_params(self): @@ -1180,17 +975,13 @@ def test_request_body_params(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -1206,52 +997,16 @@ def test_public(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_S256_authorize_get(self): """ @@ -1327,20 +1082,16 @@ def test_public_pkce_S256(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier, + "code_verifier": code_verifier } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain(self): @@ -1361,20 +1112,16 @@ def test_public_pkce_plain(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier, + "code_verifier": code_verifier } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_invalid_algorithm(self): @@ -1477,12 +1224,10 @@ def test_public_pkce_S256_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid", + "code_verifier": "invalid" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1504,12 +1249,10 @@ def test_public_pkce_plain_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid", + "code_verifier": "invalid" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1530,12 +1273,10 @@ def test_public_pkce_S256_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1556,12 +1297,10 @@ def test_public_pkce_plain_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1580,19 +1319,14 @@ def test_malicious_redirect_uri(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "/../", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual( - data["error_description"], - oauthlib_errors.MismatchingRedirectURIError.description, - ) + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) def test_code_exchange_succeed_when_redirect_uri_match(self): """ @@ -1609,9 +1343,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1619,23 +1351,17 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar", + "redirect_uri": "http://example.org?foo=bar" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1652,9 +1378,7 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1662,26 +1386,17 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa", + "redirect_uri": "http://example.org?foo=baraa" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual( - data["error_description"], - oauthlib_errors.MismatchingRedirectURIError.description, - ) + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): """ Tests code exchange succeed when redirect uri matches the one used for code request """ @@ -1698,9 +1413,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1708,72 +1421,17 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", + "redirect_uri": "http://example.com?bar=baz&foo=bar" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code", - "allow": True, - } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_html(self): """ @@ -1836,9 +1494,7 @@ def test_oob_as_json(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) self.assertEqual(response.status_code, 200) self.assertRegex(response["Content-Type"], "^application/json") @@ -1855,17 +1511,13 @@ def test_oob_as_json(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) class TestAuthorizationCodeProtectedResource(BaseTest): @@ -1881,54 +1533,7 @@ def test_resource_access_allowed(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_id_token_resource_access_allowed(self): - self.client.login(username="test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1936,18 +1541,13 @@ def test_id_token_resource_access_allowed(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token = content["access_token"] - id_token = content["id_token"] # use token to access the resource auth_headers = { @@ -1960,17 +1560,6 @@ def test_id_token_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") - # use id_token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + id_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - def test_resource_access_deny(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "faketoken", @@ -1984,6 +1573,7 @@ def test_resource_access_deny(self): class TestDefaultScopes(BaseTest): + def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py deleted file mode 100644 index 1f45aeeec..000000000 --- a/tests/test_hybrid.py +++ /dev/null @@ -1,1264 +0,0 @@ -import base64 -import datetime -import json -from urllib.parse import parse_qs, urlencode, urlparse - -from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase -from django.urls import reverse -from django.utils import timezone -from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors - -from oauth2_provider.models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model -) -from oauth2_provider.settings import oauth2_settings -from oauth2_provider.views import ProtectedResourceView - -from .utils import get_basic_auth_header - - -Application = get_application_model() -AccessToken = get_access_token_model() -Grant = get_grant_model() -RefreshToken = get_refresh_token_model() -UserModel = get_user_model() - - -# mocking a protected resource view -class ResourceView(ProtectedResourceView): - def get(self, request, *args, **kwargs): - return "This is a protected resource" - - -class BaseTest(TestCase): - def setUp(self): - self.factory = RequestFactory() - self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") - self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") - - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] - - self.application = Application( - name="Hybrid Test Application", - redirect_uris=( - "http://localhost http://example.com http://example.org custom-scheme://example.com" - ), - user=self.hy_dev_user, - client_type=Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=Application.GRANT_OPENID_HYBRID, - ) - self.application.save() - - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - - def tearDown(self): - self.application.delete() - self.hy_test_user.delete() - self.hy_dev_user.delete() - - -class TestRegressionIssue315Hybrid(BaseTest): - """ - Test to avoid regression for the issue 315: request object - was being reassigned when getting AuthorizationView - """ - - def test_request_is_not_overwritten_code_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - def test_request_is_not_overwritten_code_id_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - def test_request_is_not_overwritten_code_id_token_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - -class TestHybridView(BaseTest): - def test_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="hy_test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="hy_test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_pre_auth_invalid_client(self): - """ - Test error for an invalid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": "fakeclientid", - "response_type": "code", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.context_data["url"], - "?error=invalid_request&error_description=Invalid+client_id+parameter+value." - ) - - def test_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read write") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): - """ - Test response for a valid client_id with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "custom-scheme://example.com") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read write") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_pre_auth_approval_prompt(self): - tok = AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "approval_prompt": "auto", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - # user already authorized the application, but with different scopes: prompt them. - tok.scope = "read" - tok.save() - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - def test_pre_auth_approval_prompt_default(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") - - AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" - - AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_pre_auth_default_redirect(self): - """ - Test for default redirect uri if omitted from query string with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://localhost") - - def test_pre_auth_forbibben_redirect(self): - """ - Test error when passing a forbidden redirect_uri in query string with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "redirect_uri": "http://forbidden.it", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 400) - - def test_pre_auth_wrong_response_type(self): - """ - Test error when passing a wrong response_type in query string - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "WRONG", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - self.assertIn("error=unsupported_response_type", response["Location"]) - - def test_code_post_auth_allow_code_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_allow_code_id_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_allow_code_id_token_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) - - def test_code_post_auth_bad_responsetype(self): - """ - Test authorization code is given for an allowed request with a response_type not supported - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "UNKNOWN", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?error", response["Location"]) - - def test_code_post_auth_forbidden_redirect_uri(self): - """ - Test authorization code is given for an allowed request with a forbidden redirect_uri - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://forbidden.it", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - def test_code_post_auth_malicious_redirect_uri(self): - """ - Test validation of a malicious redirect_uri - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "/../", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_deny_custom_redirect_uri_scheme(self): - """ - Test error when resource owner deny access - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com?", response["Location"]) - self.assertIn("error=access_denied", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_failing_redirection_uri_with_querystring(self): - """ - Test that in case of error the querystring of the redirection uri is preserved - - See https://github.com/evonove/django-oauth-toolkit/issues/238 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertEqual( - "http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"] - ) - - def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): - """ - Tests that a redirection uri is matched using scheme + netloc + path - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.com/a?foo=bar", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - -class TestHybridTokenView(BaseTest): - def get_auth(self, scope="read write"): - """ - Helper method to retrieve a valid authorization code - """ - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": scope, - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - return fragment_dict["code"].pop() - - def test_basic_auth(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_basic_auth_bad_authcode(self): - """ - Request an access token using a bad authorization code - """ - self.client.login(username="hy_test_user", password="123456") - - token_request_data = { - "grant_type": "authorization_code", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_bad_granttype(self): - """ - Request an access token using a bad grant_type string - """ - self.client.login(username="hy_test_user", password="123456") - - token_request_data = { - "grant_type": "UNKNOWN", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_grant_expired(self): - """ - Request an access token using an expired grant token - """ - self.client.login(username="hy_test_user", password="123456") - g = Grant( - application=self.application, user=self.hy_test_user, code="BLAH", - expires=timezone.now(), redirect_uri="", scope="") - g.save() - - token_request_data = { - "grant_type": "authorization_code", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_bad_secret(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) - - def test_basic_auth_wrong_auth_type(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - - user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) - auth_string = base64.b64encode(user_pass.encode("utf-8")) - auth_headers = { - "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) - - def test_request_body_params(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "client_secret": self.application.client_secret, - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_malicious_redirect_uri(self): - """ - Request an access token using client_type: public and ensure redirect_uri is - properly validated. - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "/../", - "client_id": self.application.client_id - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 400) - data = response.json() - self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - - def test_code_exchange_succeed_when_redirect_uri_match(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org?foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_code_exchange_fails_when_redirect_uri_does_not_match(self): - """ - Tests code exchange fails when redirect uri does not match the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org?foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - query_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - data = response.json() - self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - -class TestHybridProtectedResource(BaseTest): - def test_resource_access_allowed(self): - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_id_token_resource_access_allowed(self): - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - id_token = content["id_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - # use id_token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + id_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_resource_access_deny(self): - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + "faketoken", - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response.status_code, 403) - - -class TestDefaultScopesHybrid(BaseTest): - - def test_pre_auth_default_scopes(self): - """ - Test response for a valid client_id with response_type: code using default scopes - """ - self.client.login(username="hy_test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read") - self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 15ac7469d..b51d0e1da 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,10 +1,8 @@ -import json from urllib.parse import parse_qs, urlparse from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse -from jwcrypto import jwk, jwt from oauth2_provider.models import get_application_model from oauth2_provider.settings import oauth2_settings @@ -35,14 +33,8 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._SCOPES = ["read", "write"] oauth2_settings._DEFAULT_SCOPES = ["read"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) def tearDown(self): self.application.delete() @@ -273,191 +265,3 @@ def test_resource_access_allowed(self): view = ResourceView.as_view() response = view(request) self.assertEqual(response, "This is a protected resource") - - -class TestOpenIDConnectImplicitFlow(BaseTest): - def test_id_token_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: id_token - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertNotIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertNotIn("at_hash", claims) - - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token", - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertNotIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertNotIn("at_hash", claims) - - def test_id_token_skip_authorization_completely_missing_nonce(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=invalid_request", response["Location"]) - self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) - - def test_id_token_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) - - def test_access_token_and_id_token_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: token - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertIn("at_hash", claims) - - def test_access_token_and_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token token", - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertIn("at_hash", claims) - - def test_access_token_and_id_token_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token token", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index 0d98dad8b..d844da5f4 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -65,9 +65,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch( - "oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response" - ) as create_token_response: + with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 1a0926988..7821148d5 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -287,13 +287,6 @@ def test_save_bearer_token__with_new_token__calls_methods_to_create_access_and_r assert create_access_token_mock.call_count == 1 assert create_refresh_token_mock.call_count == 1 - def test_generate_at_hash(self): - # Values taken from spec, https://openid.net/specs/openid-connect-core-1_0.html#id_token-tokenExample - access_token = "jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y" - at_hash = self.validator.generate_at_hash(access_token) - - assert at_hash == "77QmUPtjPfzWtF2AnpK9RQ" - class TestOAuth2ValidatorProvidesErrorData(TransactionTestCase): """These test cases check that the recommended error codes are returned diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py deleted file mode 100644 index 71f41d7eb..000000000 --- a/tests/test_oidc_views.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import unicode_literals - -from django.test import TestCase -from django.urls import reverse - -from oauth2_provider.settings import oauth2_settings - - -class TestConnectDiscoveryInfoView(TestCase): - def test_get_connect_discovery_info(self): - expected_response = { - "issuer": "http://localhost", - "authorization_endpoint": "http://localhost/o/authorize/", - "token_endpoint": "http://localhost/o/token/", - "userinfo_endpoint": "http://localhost/userinfo/", - "jwks_uri": "http://localhost/o/jwks/", - "response_types_supported": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token" - ], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256", "HS256"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] - } - response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response - - def test_get_connect_discovery_info_without_issuer_url(self): - oauth2_settings.OIDC_ISS_ENDPOINT = None - oauth2_settings.OIDC_USERINFO_ENDPOINT = None - expected_response = { - "issuer": "http://testserver/o", - "authorization_endpoint": "http://testserver/o/authorize/", - "token_endpoint": "http://testserver/o/token/", - "userinfo_endpoint": "http://testserver/o/userinfo/", - "jwks_uri": "http://testserver/o/jwks/", - "response_types_supported": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token" - ], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256", "HS256"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] - } - response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response - oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" - oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" - - -class TestJwksInfoView(TestCase): - def test_get_jwks_info(self): - expected_response = { - "keys": [{ - "alg": "RS256", - "use": "sig", - "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", - "e": "AQAB", - "kty": "RSA", - "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8" # noqa - }] - } - response = self.client.get(reverse("oauth2_provider:jwks-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response diff --git a/tests/urls.py b/tests/urls.py index c7fa9a101..16dcf6ded 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,11 +1,13 @@ +from django.conf.urls import include, url from django.contrib import admin -from django.urls import include, re_path admin.autodiscover() urlpatterns = [ - re_path(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), - re_path(r"^admin/", admin.site.urls), + url(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] + + +urlpatterns += [url(r"^admin/", admin.site.urls)] diff --git a/tox.ini b/tox.ini index 686bf366a..c984f8b99 100644 --- a/tox.ini +++ b/tox.ini @@ -14,8 +14,7 @@ envlist = django_find_project = false [testenv] -commands = - pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s +commands = pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} setenv = DJANGO_SETTINGS_MODULE = tests.settings PYTHONPATH = {toxinidir} @@ -27,7 +26,6 @@ deps = djangomaster: https://github.com/django/django/archive/master.tar.gz djangorestframework oauthlib>=3.1.0 - jwcrypto coverage pytest pytest-cov @@ -44,7 +42,6 @@ commands = make html deps = sphinx<3 oauthlib>=3.1.0 m2r>=0.2.1 - jwcrypto [testenv:py37-flake8] skip_install = True @@ -70,9 +67,7 @@ commands = [coverage:run] source = oauth2_provider -omit = - */migrations/* - oauth2_provider/settings.py +omit = */migrations/* [flake8] max-line-length = 110