From 183b02aa58249553a25164da5eaa0a0177b2c943 Mon Sep 17 00:00:00 2001 From: Tom Evans Date: Mon, 25 Jan 2021 14:08:00 +0000 Subject: [PATCH] Make OIDC support optional Make OIDC support optional by not requiring OIDC_RSA_PRIVATE_KEY to be set in the settings, and using the standard oauthlib.oauth2.Server class when an OIDC private key is not configured. Add a test fixture wrapping oauth2_settings. This allows individual tests / test suites to override oauth2 settings and have them reset at the end of the test. This avoids configuration leaking from one test to another, and allows us to test multiple different configurations in one test run. When using the oauth2_settings fixture, allow configuration for the test case to be loaded from a pytest marker called oauth2_settings. Split out OIDC specific tests requiring specific OIDC configuration into separate TestCase. Adjust the OAuthLibMixin to fallback to using the server, validator and core classes specified in oauth2_settings when not hardcoded in to the class. These classes can still be specified as hard-coded attributes in sub-classes, but it's no longer required if you just want what is configured in oauth2_settings, so remove all attributes that are just pointing at the configuration anyway. Add a setting ALWAYS_RELOAD_OAUTHLIB_CORE, which causes OAuthLibMixin to reload the OAuthLibCore object on each request. This is only intended to be used during testing, to allow the views to recognise changes in configuration. Show missing coverage lines in the coverage report. Fixes: #873 --- oauth2_provider/settings.py | 19 +- oauth2_provider/views/base.py | 12 - oauth2_provider/views/generic.py | 14 +- oauth2_provider/views/mixins.py | 21 +- oauth2_provider/views/oidc.py | 4 - tests/conftest.py | 67 ++++++ tests/presets.py | 42 ++++ tests/settings.py | 6 - tests/test_application_views.py | 8 +- tests/test_authorization_code.py | 393 +++++++++++++++---------------- tests/test_client_credential.py | 8 +- tests/test_decorators.py | 3 - tests/test_generator.py | 23 +- tests/test_hybrid.py | 41 ++-- tests/test_implicit.py | 20 +- tests/test_introspection_auth.py | 27 +-- tests/test_introspection_view.py | 10 +- tests/test_mixins.py | 37 ++- tests/test_models.py | 49 +--- tests/test_oauth2_backends.py | 12 +- tests/test_oidc_views.py | 13 +- tests/test_password.py | 8 +- tests/test_rest_framework.py | 17 +- tests/test_scopes.py | 32 ++- tests/test_token_revocation.py | 3 - tests/test_validators.py | 5 +- tox.ini | 6 +- 27 files changed, 478 insertions(+), 422 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/presets.py diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 331d3148e..1cfa05e2c 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -37,7 +37,8 @@ "ACCESS_TOKEN_GENERATOR": None, "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, - "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OIDC_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -92,6 +93,9 @@ "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, # Whether or not PKCE is required "PKCE_REQUIRED": False, + # Whether to re-create OAuthlibCore on every request. + # Should only be required in testing. + "ALWAYS_RELOAD_OAUTHLIB_CORE": False, } # List of settings that cannot be empty @@ -103,7 +107,6 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", - "OIDC_RSA_PRIVATE_KEY", "OIDC_RESPONSE_TYPES_SUPPORTED", "OIDC_SUBJECT_TYPES_SUPPORTED", "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", @@ -182,13 +185,17 @@ def user_settings(self): def __getattr__(self, attr): if attr not in self.defaults: raise AttributeError("Invalid OAuth2Provider setting: %s" % attr) - try: # Check if present in user settings val = self.user_settings[attr] except KeyError: # Fall back to defaults - val = self.defaults[attr] + # Special case OAUTH2_SERVER_CLASS - if not specified, and OIDC is + # enabled, use the OIDC_SERVER_CLASS setting instead + if attr == "OAUTH2_SERVER_CLASS" and self.is_oidc_enabled: + val = self.defaults["OIDC_SERVER_CLASS"] + else: + val = self.defaults[attr] # Coerce import strings into classes if val and attr in self.import_strings: @@ -254,6 +261,10 @@ def reload(self): if hasattr(self, "_user_settings"): delattr(self, "_user_settings") + @property + def is_oidc_enabled(self): + return bool(self.OIDC_RSA_PRIVATE_KEY) + oauth2_settings = OAuth2ProviderSettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS, MANDATORY) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index 7aa82aa7e..82bdc2502 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -90,10 +90,6 @@ class AuthorizationView(BaseAuthorizationView, FormView): template_name = "oauth2_provider/authorize.html" form_class = AllowForm - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - skip_authorization_completely = False def get_initial(self): @@ -267,10 +263,6 @@ class TokenView(OAuthLibMixin, View): * Client credentials """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - @method_decorator(sensitive_post_parameters("password")) def post(self, request, *args, **kwargs): url, headers, body, status = self.create_token_response(request) @@ -292,10 +284,6 @@ class RevokeTokenView(OAuthLibMixin, View): Implements an endpoint to revoke access or refresh tokens """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - def post(self, request, *args, **kwargs): url, headers, body, status = self.create_revocation_response(request) response = HttpResponse(content=body or "", status=status) diff --git a/oauth2_provider/views/generic.py b/oauth2_provider/views/generic.py index 10e84d59f..da675eac4 100644 --- a/oauth2_provider/views/generic.py +++ b/oauth2_provider/views/generic.py @@ -1,6 +1,5 @@ from django.views.generic import View -from ..settings import oauth2_settings from .mixins import ( ClientProtectedResourceMixin, OAuthLibMixin, @@ -10,16 +9,7 @@ ) -class InitializationMixin(OAuthLibMixin): - - """Initializer for OauthLibMixin""" - - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - - -class ProtectedResourceView(ProtectedResourceMixin, InitializationMixin, View): +class ProtectedResourceView(ProtectedResourceMixin, OAuthLibMixin, View): """ Generic view protecting resources by providing OAuth2 authentication out of the box """ @@ -45,7 +35,7 @@ class ReadWriteScopedResourceView(ReadWriteScopedResourceMixin, ProtectedResourc pass -class ClientProtectedResourceView(ClientProtectedResourceMixin, InitializationMixin, View): +class ClientProtectedResourceView(ClientProtectedResourceMixin, OAuthLibMixin, View): """View for protecting a resource with client-credentials method. This involves allowing access tokens, Basic Auth and plain credentials in request body. diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 4fcea0a47..ed6448fc1 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -25,6 +25,9 @@ class OAuthLibMixin: * validator_class * oauthlib_backend_class + If these class variables are not set, it will fall back to using the classes + specified in oauth2_settings (OAUTH2_SERVER_CLASS, OAUTH2_VALIDATOR_CLASS + and OAUTH2_BACKEND_CLASS). """ server_class = None @@ -37,10 +40,7 @@ def get_server_class(cls): Return the OAuthlib server class to use """ if cls.server_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'server_class'" - " or an implementation of 'get_server_class()'" - ) + return oauth2_settings.OAUTH2_SERVER_CLASS else: return cls.server_class @@ -50,10 +50,7 @@ def get_validator_class(cls): Return the RequestValidator implementation class to use """ if cls.validator_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'validator_class'" - " or an implementation of 'get_validator_class()'" - ) + return oauth2_settings.OAUTH2_VALIDATOR_CLASS else: return cls.validator_class @@ -63,10 +60,7 @@ def get_oauthlib_backend_class(cls): Return the OAuthLibCore implementation class to use """ if cls.oauthlib_backend_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'oauthlib_backend_class'" - " or an implementation of 'get_oauthlib_backend_class()'" - ) + return oauth2_settings.OAUTH2_BACKEND_CLASS else: return cls.oauthlib_backend_class @@ -85,8 +79,9 @@ def get_server(cls): def get_oauthlib_core(cls): """ Cache and return `OAuthlibCore` instance so it will be created only on first request + unless ALWAYS_RELOAD_OAUTHLIB_CORE is True. """ - if not hasattr(cls, "_oauthlib_core"): + if not hasattr(cls, "_oauthlib_core") or oauth2_settings.ALWAYS_RELOAD_OAUTHLIB_CORE: server = cls.get_server() core_class = cls.get_oauthlib_backend_class() cls._oauthlib_core = core_class(server) diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py index f1bb93195..345112b23 100644 --- a/oauth2_provider/views/oidc.py +++ b/oauth2_provider/views/oidc.py @@ -79,10 +79,6 @@ class UserInfoView(OAuthLibMixin, View): View used to show Claims about the authenticated End-User """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - def get(self, request, *args, **kwargs): url, headers, body, status = self.create_userinfo_response(request) response = HttpResponse(content=body or "", status=status) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..dadcbe9a7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,67 @@ +import pytest +from django.conf import settings as test_settings +from jwcrypto import jwk + +from oauth2_provider.settings import oauth2_settings as _oauth2_settings + + +class OAuthSettingsWrapper: + """ + A wrapper around oauth2_settings to ensure that when an overridden value is + set, it also records it in _cached_attrs, so that the settings can be reset. + """ + + def __init__(self, settings, user_settings): + if user_settings: + settings.OAUTH2_PROVIDER = user_settings + _oauth2_settings.reload() + self.settings = settings + # Reload OAuthlibCore for every view request during tests + self.ALWAYS_RELOAD_OAUTHLIB_CORE = True + + def __setattr__(self, attr, value): + setattr(_oauth2_settings, attr, value) + _oauth2_settings._cached_attrs.add(attr) + + def __delattr__(self, attr): + delattr(_oauth2_settings, attr) + if attr in _oauth2_settings._cached_attrs: + _oauth2_settings._cached_attrs.remove(attr) + + def __getattr__(self, attr): + return getattr(_oauth2_settings, attr) + + def finalize(self): + self.settings.finalize() + _oauth2_settings.reload() + + +@pytest.fixture +def oauth2_settings(request, settings): + """ + A fixture that provides a simple way to override OAUTH2_PROVIDER settings. + + It can be used two ways - either setting things on the fly, or by reading + configuration data from the pytest marker oauth2_settings. + + If used on a standard pytest function, you can use argument dependency + injection to get the wrapper. If used on a unittest.TestCase, the wrapper + is made available on the class instance, as `oauth2_settings`. + + Anything overridden will be restored at the end of the test case, ensuring + that there is no configuration leakage between test cases. + """ + marker = request.node.get_closest_marker("oauth2_settings") + user_settings = {} + if marker is not None: + user_settings = marker.args[0] + wrapper = OAuthSettingsWrapper(settings, user_settings) + if request.instance is not None: + request.instance.oauth2_settings = wrapper + yield wrapper + wrapper.finalize() + + +@pytest.fixture(scope="class") +def oidc_key(request): + request.cls.key = jwk.JWK.from_pem(test_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) diff --git a/tests/presets.py b/tests/presets.py new file mode 100644 index 000000000..aa264429c --- /dev/null +++ b/tests/presets.py @@ -0,0 +1,42 @@ +from copy import deepcopy + +from django.conf import settings + + +# A set of OAUTH2_PROVIDER settings dicts that can be used in tests + +DEFAULT_SCOPES_RW = {"DEFAULT_SCOPES": ["read", "write"]} +DEFAULT_SCOPES_RO = {"DEFAULT_SCOPES": ["read"]} +OIDC_SETTINGS_RW = { + "OIDC_ISS_ENDPOINT": "http://localhost", + "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", + "OIDC_RSA_PRIVATE_KEY": settings.OIDC_RSA_PRIVATE_KEY, + "SCOPES": { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect", + }, + "DEFAULT_SCOPES": ["read", "write"], +} +OIDC_SETTINGS_RO = deepcopy(OIDC_SETTINGS_RW) +OIDC_SETTINGS_RO["DEFAULT_SCOPES"] = ["read"] +REST_FRAMEWORK_SCOPES = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Scope 1", + "scope2": "Scope 2", + "resource1": "Resource 1", + }, +} +INTROSPECTION_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "introspection": "Introspection scope", + "dolphin": "eek eek eek scope", + }, + "RESOURCE_SERVER_INTROSPECTION_URL": "http://example.org/introspection", + "READ_SCOPE": "read", + "WRITE_SCOPE": "write", +} diff --git a/tests/settings.py b/tests/settings.py index 1bb7d43df..1d295982e 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -134,12 +134,6 @@ dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY -----END RSA PRIVATE KEY-----""" -OAUTH2_PROVIDER = { - "OIDC_ISS_ENDPOINT": "http://localhost", - "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", - "OIDC_RSA_PRIVATE_KEY": OIDC_RSA_PRIVATE_KEY, -} - OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index e5c897ab9..33942f9ba 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -1,9 +1,9 @@ +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views.application import ApplicationRegistration from .models import SampleApplication @@ -23,21 +23,19 @@ def tearDown(self): self.bar_user.delete() +@pytest.mark.usefixtures("oauth2_settings") class TestApplicationRegistrationView(BaseTest): + @pytest.mark.oauth2_settings({"APPLICATION_MODEL": "tests.SampleApplication"}) def test_get_form_class(self): """ Tests that the form class returned by the "get_form_class" method is bound to custom application model defined in the "OAUTH2_PROVIDER_APPLICATION_MODEL" setting. """ - # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.SampleApplication" # Create a registration view and tests that the model form is bound # to the custom Application model application_form_class = ApplicationRegistration().get_form_class() self.assertEqual(SampleApplication, application_form_class._meta.model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" def test_application_registration_user(self): self.client.login(username="foo_user", password="123456") diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 49d0f55be..688759cfa 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -5,6 +5,7 @@ import re from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -18,9 +19,9 @@ get_grant_model, get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets from .utils import get_basic_auth_header @@ -40,13 +41,14 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] self.application = Application.objects.create( name="Test Application", @@ -59,14 +61,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } - def tearDown(self): self.application.delete() self.test_user.delete() @@ -95,6 +89,7 @@ def test_request_is_not_overwritten(self): assert "request" not in response.context_data +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeView(BaseTest): def test_skip_authorization_completely(self): """ @@ -116,25 +111,6 @@ def test_skip_authorization_completely(self): ) self.assertEqual(response.status_code, 302) - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - def test_pre_auth_invalid_client(self): """ Test error for an invalid client_id with response_type: code @@ -179,32 +155,6 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="test_user", password="123456") - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -260,7 +210,7 @@ def test_pre_auth_approval_prompt(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default(self): - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( user=self.test_user, @@ -281,7 +231,7 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( user=self.test_user, @@ -369,27 +319,6 @@ def test_code_post_auth_allow(self): self.assertIn("state=random_state_string", response["Location"]) self.assertIn("code=", response["Location"]) - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - def test_code_post_auth_deny(self): """ Test error when resource owner deny access @@ -594,7 +523,76 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) -class TestAuthorizationCodeTokenView(BaseTest): +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeView(BaseTest): + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="test_user", password="123456") + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + + +class BaseAuthorizationCodeTokenView(BaseTest): def get_auth(self, scope="read write"): """ Helper method to retrieve a valid authorization code @@ -629,7 +627,7 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): """ Helper method to retrieve a valid authorization code using pkce """ - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", @@ -643,9 +641,11 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) - oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) +class TestAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): def test_basic_auth(self): """ Request an access token using basic authentication for client authentication @@ -666,7 +666,7 @@ def test_basic_auth(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_refresh(self): """ @@ -716,7 +716,7 @@ def test_refresh_with_grace_period(self): """ Request an access token using a refresh token """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -763,7 +763,6 @@ def test_refresh_with_grace_period(self): # refresh token should be the same as well self.assertTrue("refresh_token" in content) self.assertEqual(content["refresh_token"], first_refresh_token) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_invalidates_old_tokens(self): """ @@ -884,7 +883,7 @@ def test_refresh_repeating_requests(self): Trying to refresh an access token with the same refresh token more than once succeeds in the grace period and fails outside """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -917,7 +916,6 @@ def test_refresh_repeating_requests(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_repeating_requests_non_rotating_tokens(self): """ @@ -942,15 +940,13 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - oauth2_settings.ROTATE_REFRESH_TOKEN = False + self.oauth2_settings.ROTATE_REFRESH_TOKEN = False response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - oauth2_settings.ROTATE_REFRESH_TOKEN = True - def test_basic_auth_bad_authcode(self): """ Request an access token using a bad authorization code @@ -1064,7 +1060,7 @@ def test_request_body_params(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -1089,35 +1085,7 @@ def test_public(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_S256_authorize_get(self): """ @@ -1130,7 +1098,7 @@ def test_public_pkce_S256_authorize_get(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1146,7 +1114,6 @@ def test_public_pkce_S256_authorize_get(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertContains(response, 'value="S256"', count=1, status_code=200) self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_authorize_get(self): """ @@ -1159,7 +1126,7 @@ def test_public_pkce_plain_authorize_get(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1175,7 +1142,6 @@ def test_public_pkce_plain_authorize_get(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertContains(response, 'value="plain"', count=1, status_code=200) self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256(self): """ @@ -1188,7 +1154,7 @@ def test_public_pkce_S256(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1204,8 +1170,7 @@ def test_public_pkce_S256(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_plain(self): """ @@ -1218,7 +1183,7 @@ def test_public_pkce_plain(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1234,8 +1199,7 @@ def test_public_pkce_plain(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_invalid_algorithm(self): """ @@ -1247,7 +1211,7 @@ def test_public_pkce_invalid_algorithm(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("invalid") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1263,7 +1227,6 @@ def test_public_pkce_invalid_algorithm(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge(self): """ @@ -1276,7 +1239,7 @@ def test_public_pkce_missing_code_challenge(self): self.application.skip_authorization = True self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1291,7 +1254,6 @@ def test_public_pkce_missing_code_challenge(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge_method(self): """ @@ -1303,7 +1265,7 @@ def test_public_pkce_missing_code_challenge_method(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1317,7 +1279,6 @@ def test_public_pkce_missing_code_challenge_method(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_invalid_code_verifier(self): """ @@ -1330,7 +1291,7 @@ def test_public_pkce_S256_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1342,7 +1303,6 @@ def test_public_pkce_S256_invalid_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_invalid_code_verifier(self): """ @@ -1355,7 +1315,7 @@ def test_public_pkce_plain_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1367,7 +1327,6 @@ def test_public_pkce_plain_invalid_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_missing_code_verifier(self): """ @@ -1380,7 +1339,7 @@ def test_public_pkce_S256_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1391,7 +1350,6 @@ def test_public_pkce_S256_missing_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_missing_code_verifier(self): """ @@ -1404,7 +1362,7 @@ def test_public_pkce_plain_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1415,7 +1373,6 @@ def test_public_pkce_plain_missing_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_malicious_redirect_uri(self): """ @@ -1477,7 +1434,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1552,48 +1509,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_html(self): """ @@ -1639,7 +1555,7 @@ def test_oob_as_html(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_json(self): """ @@ -1679,9 +1595,82 @@ def test_oob_as_json(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeProtectedResource(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -1722,6 +1711,20 @@ def test_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") + def test_resource_access_deny(self): + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "faketoken", + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeProtectedResource(BaseTest): def test_id_token_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -1773,25 +1776,14 @@ def test_id_token_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") - def test_resource_access_deny(self): - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + "faketoken", - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response.status_code, 403) - +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestDefaultScopes(BaseTest): def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes """ self.client.login(username="test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] query_data = { "client_id": self.application.client_id, @@ -1811,4 +1803,3 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["state"].value(), "random_state_string") self.assertEqual(form["scope"].value(), "read") self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 966eb826b..8b9aa3bc2 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -1,6 +1,7 @@ import json from urllib.parse import quote_plus +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -10,10 +11,10 @@ from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from oauth2_provider.views.mixins import OAuthLibMixin +from . import presets from .utils import get_basic_auth_header @@ -28,6 +29,8 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -41,9 +44,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 22ce48e76..ce17a891a 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -6,7 +6,6 @@ from oauth2_provider.decorators import protected_resource, rw_protected_resource from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -37,8 +36,6 @@ def setUp(self): application=self.application, ) - oauth2_settings._SCOPES = ["read", "write"] - def test_access_denied(self): @protected_resource() def view(request, *args, **kwargs): diff --git a/tests/test_generator.py b/tests/test_generator.py index 670ac9ea1..cc7928017 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,13 +1,7 @@ +import pytest from django.test import TestCase -from oauth2_provider.generators import ( - BaseHashGenerator, - ClientIdGenerator, - ClientSecretGenerator, - generate_client_id, - generate_client_secret, -) -from oauth2_provider.settings import oauth2_settings +from oauth2_provider.generators import BaseHashGenerator, generate_client_id, generate_client_secret class MockHashGenerator(BaseHashGenerator): @@ -15,23 +9,20 @@ def hash(self): return 42 +@pytest.mark.usefixtures("oauth2_settings") class TestGenerators(TestCase): - def tearDown(self): - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = ClientIdGenerator - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = ClientSecretGenerator - def test_generate_client_id(self): - g = oauth2_settings.CLIENT_ID_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 40) - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_id(), 42) def test_generate_secret_id(self): - g = oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 128) - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_secret(), 42) def test_basegen_misuse(self): diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index a01e0ba46..6c9b86b62 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -3,6 +3,7 @@ import json from urllib.parse import parse_qs, urlencode, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -15,9 +16,9 @@ get_grant_model, get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets from .utils import get_basic_auth_header @@ -34,13 +35,14 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] self.application = Application( name="Hybrid Test Application", @@ -53,20 +55,13 @@ def setUp(self): ) self.application.save() - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } - def tearDown(self): self.application.delete() self.hy_test_user.delete() self.hy_dev_user.delete() +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestRegressionIssue315Hybrid(BaseTest): """ Test to avoid regression for the issue 315: request object @@ -127,6 +122,7 @@ def test_request_is_not_overwritten_code_id_token_token(self): assert "request" not in response.context_data +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridView(BaseTest): def test_skip_authorization_completely(self): """ @@ -311,8 +307,8 @@ def test_pre_auth_approval_prompt(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( user=self.hy_test_user, @@ -336,7 +332,7 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( user=self.hy_test_user, @@ -788,6 +784,7 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridTokenView(BaseTest): def get_auth(self, scope="read write"): """ @@ -827,7 +824,7 @@ def test_basic_auth(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_basic_auth_bad_authcode(self): """ @@ -942,7 +939,7 @@ def test_request_body_params(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -967,7 +964,7 @@ def test_public(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_id_token_public(self): """ @@ -995,7 +992,7 @@ def test_id_token_public(self): self.assertEqual(content["scope"], "openid") self.assertIn("access_token", content) self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_malicious_redirect_uri(self): """ @@ -1054,7 +1051,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1124,7 +1121,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): """ @@ -1163,9 +1160,10 @@ def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_qu self.assertEqual(content["scope"], "openid") self.assertIn("access_token", content) self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestHybridProtectedResource(BaseTest): def test_resource_access_allowed(self): self.client.login(username="hy_test_user", password="123456") @@ -1269,13 +1267,13 @@ def test_resource_access_deny(self): self.assertEqual(response.status_code, 403) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RO) class TestDefaultScopesHybrid(BaseTest): def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes """ self.client.login(username="hy_test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] query_string = urlencode( { @@ -1298,4 +1296,3 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["state"].value(), "random_state_string") self.assertEqual(form["scope"].value(), "read") self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_implicit.py b/tests/test_implicit.py index c47ee5031..fe9bb196e 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,15 +1,17 @@ import json from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse -from jwcrypto import jwk, jwt +from jwcrypto import jwt from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets + Application = get_application_model() UserModel = get_user_model() @@ -21,6 +23,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -35,21 +38,13 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } - self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - def tearDown(self): self.application.delete() self.test_user.delete() self.dev_user.delete() +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitAuthorizationCodeView(BaseTest): def test_pre_auth_valid_client_default_scopes(self): """ @@ -245,6 +240,7 @@ def test_implicit_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitTokenView(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -275,6 +271,8 @@ def test_resource_access_allowed(self): self.assertEqual(response, "This is a protected resource") +@pytest.mark.usefixtures("oidc_key") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestOpenIDConnectImplicitFlow(BaseTest): def test_id_token_post_auth_allow(self): """ diff --git a/tests/test_introspection_auth.py b/tests/test_introspection_auth.py index 5fc12b6b1..9f871cdea 100644 --- a/tests/test_introspection_auth.py +++ b/tests/test_introspection_auth.py @@ -1,6 +1,7 @@ import calendar import datetime +import pytest from django.conf.urls import include from django.contrib.auth import get_user_model from django.http import HttpResponse @@ -11,9 +12,10 @@ from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ScopedProtectedResourceView +from . import presets + try: from unittest import mock @@ -78,6 +80,8 @@ def json(self): @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionAuth(TestCase): """ Tests for Authorization through token introspection @@ -114,16 +118,9 @@ def setUp(self): scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = "http://example.org/introspection" - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = None - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = None self.resource_server_token.delete() self.application.delete() AccessToken.objects.all().delete() @@ -136,9 +133,9 @@ def test_get_token_from_authentication_server_not_existing_token(self, mock_get) """ token = self.validator._get_token_from_authentication_server( self.resource_server_token.token, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsNone(token) @@ -149,9 +146,9 @@ def test_get_token_from_authentication_server_existing_token(self, mock_get): """ token = self.validator._get_token_from_authentication_server( "foo", - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsInstance(token, AccessToken) self.assertEqual(token.user.username, "foo_user") diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index 5b3fc58f8..0f68320ca 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -1,14 +1,15 @@ import calendar import datetime +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings +from . import presets from .utils import get_basic_auth_header @@ -17,6 +18,8 @@ UserModel = get_user_model() +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionViews(TestCase): """ Tests for Authorized Token Introspection Views @@ -74,12 +77,7 @@ def setUp(self): scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] AccessToken.objects.all().delete() Application.objects.all().delete() UserModel.objects.all().delete() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 793a5b4b4..22810c043 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,3 +1,4 @@ +import pytest from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.views.generic import View @@ -8,6 +9,7 @@ from oauth2_provider.views.mixins import OAuthLibMixin, ProtectedResourceMixin, ScopedResourceMixin +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): @classmethod def setUpClass(cls): @@ -16,32 +18,55 @@ def setUpClass(cls): class TestOAuthLibMixin(BaseTest): - def test_missing_oauthlib_backend_class(self): + def test_missing_oauthlib_backend_class_uses_fallback(self): + class CustomOauthLibBackend: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_BACKEND_CLASS = CustomOauthLibBackend + class TestView(OAuthLibMixin, View): server_class = Server validator_class = OAuth2Validator test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_oauthlib_backend_class) + self.assertEqual(CustomOauthLibBackend, test_view.get_oauthlib_backend_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core, CustomOauthLibBackend)) + + def test_missing_server_class_uses_fallback(self): + class CustomServer: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_SERVER_CLASS = CustomServer - def test_missing_server_class(self): class TestView(OAuthLibMixin, View): validator_class = OAuth2Validator oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomServer, test_view.get_server_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server, CustomServer)) + + def test_missing_validator_class_uses_fallback(self): + class CustomValidator: + pass + + self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator - def test_missing_validator_class(self): class TestView(OAuthLibMixin, View): server_class = Server oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomValidator, test_view.get_validator_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server.request_validator, CustomValidator)) def test_correct_server(self): class TestView(OAuthLibMixin, View): diff --git a/tests/test_models.py b/tests/test_models.py index afcd6b419..6a182b0c3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,7 +12,6 @@ get_grant_model, get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -108,6 +107,7 @@ def test_scopes_property(self): OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL="tests.SampleRefreshToken", OAUTH2_PROVIDER_GRANT_MODEL="tests.SampleGrant", ) +@pytest.mark.usefixtures("oauth2_settings") class TestCustomModels(BaseTestModels): def test_custom_application_model(self): """ @@ -126,22 +126,16 @@ def test_custom_application_model(self): def test_custom_application_model_incorrect_format(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" + self.oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" self.assertRaises(ValueError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_application_model_not_installed(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" + self.oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" self.assertRaises(LookupError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_access_token_model(self): """ If a custom access token model is installed, it should be present in @@ -158,22 +152,16 @@ def test_custom_access_token_model(self): def test_custom_access_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" self.assertRaises(ValueError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_access_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" self.assertRaises(LookupError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_refresh_token_model(self): """ If a custom refresh token model is installed, it should be present in @@ -190,22 +178,16 @@ def test_custom_refresh_token_model(self): def test_custom_refresh_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom RefreshToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" self.assertRaises(ValueError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_refresh_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" self.assertRaises(LookupError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_grant_model(self): """ If a custom grant model is installed, it should be present in @@ -222,22 +204,16 @@ def test_custom_grant_model(self): def test_custom_grant_model_incorrect_format(self): # Patch oauth2 settings to use a custom Grant model - oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" + self.oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" self.assertRaises(ValueError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" - def test_custom_grant_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" + self.oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" self.assertRaises(LookupError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" - class TestGrantModel(BaseTestModels): def setUp(self): @@ -310,6 +286,7 @@ def test_str(self): self.assertEqual("%s" % refresh_token, refresh_token.token) +@pytest.mark.usefixtures("oauth2_settings") class TestClearExpired(BaseTestModels): def setUp(self): super().setUp() @@ -341,11 +318,11 @@ def setUp(self): ) def test_clear_expired_tokens(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 assert clear_expired() is None def test_clear_expired_tokens_incorect_timetype(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" with pytest.raises(ImproperlyConfigured) as excinfo: clear_expired() result = excinfo.value.__class__.__name__ @@ -353,7 +330,7 @@ def test_clear_expired_tokens_incorect_timetype(self): def test_clear_expired_tokens_with_tokens(self): self.client.login(username="test_user", password="123456") - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 ttokens = AccessToken.objects.count() expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() assert ttokens == 2 diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index 6be968869..860cbb461 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -1,5 +1,6 @@ import json +import pytest from django.test import RequestFactory, TestCase from oauth2_provider.backends import get_oauthlib_core @@ -12,15 +13,16 @@ import mock +@pytest.mark.usefixtures("oauth2_settings") class TestOAuthLibCoreBackend(TestCase): def setUp(self): self.factory = RequestFactory() self.oauthlib_core = OAuthLibCore() def test_swappable_server_class(self): - with mock.patch("oauth2_provider.oauth2_backends.oauth2_settings.OAUTH2_SERVER_CLASS"): - oauthlib_core = OAuthLibCore() - self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) + self.oauth2_settings.OAUTH2_SERVER_CLASS = mock.MagicMock + oauthlib_core = OAuthLibCore() + self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) def test_form_urlencoded_extract_params(self): payload = "grant_type=password&username=john&password=123456" @@ -67,9 +69,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch( - "oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response" - ) as create_token_response: + with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index bdac228b9..bc15a33d8 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -1,11 +1,14 @@ from __future__ import unicode_literals +import pytest from django.test import TestCase from django.urls import reverse -from oauth2_provider.settings import oauth2_settings +from . import presets +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestConnectDiscoveryInfoView(TestCase): def test_get_connect_discovery_info(self): expected_response = { @@ -32,8 +35,8 @@ def test_get_connect_discovery_info(self): assert response.json() == expected_response def test_get_connect_discovery_info_without_issuer_url(self): - oauth2_settings.OIDC_ISS_ENDPOINT = None - oauth2_settings.OIDC_USERINFO_ENDPOINT = None + self.oauth2_settings.OIDC_ISS_ENDPOINT = None + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None expected_response = { "issuer": "http://testserver/o", "authorization_endpoint": "http://testserver/o/authorize/", @@ -56,10 +59,10 @@ def test_get_connect_discovery_info_without_issuer_url(self): response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) assert response.json() == expected_response - oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" - oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) class TestJwksInfoView(TestCase): def test_get_jwks_info(self): expected_response = { diff --git a/tests/test_password.py b/tests/test_password.py index f50404f9f..16546e895 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -1,11 +1,11 @@ import json +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from .utils import get_basic_auth_header @@ -21,6 +21,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -34,9 +35,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_PASSWORD, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -61,7 +59,7 @@ def test_get_token(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_bad_credentials(self): """ diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index f23891dca..a25611b93 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -1,5 +1,6 @@ from datetime import timedelta +import pytest from django.conf.urls import include from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured @@ -22,13 +23,8 @@ TokenMatchesOASRequirements, ) from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings - -try: - from unittest import mock -except ImportError: - import mock +from . import presets Application = get_application_model() @@ -131,10 +127,10 @@ class AuthenticationNoneOAuth2View(MockView): @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.REST_FRAMEWORK_SCOPES) class TestOAuth2Authentication(TestCase): def setUp(self): - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "resource1"] - self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") @@ -154,9 +150,6 @@ def setUp(self): application=self.application, ) - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - def _create_authorization_header(self, token): return "Bearer {0}".format(token) @@ -311,8 +304,8 @@ def test_resource_scoped_permission_post_denied(self): response = self.client.post("/oauth2-resource-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) - @mock.patch.object(oauth2_settings, "ERROR_RESPONSE_WITH_SCOPES", new=True) def test_required_scope_in_response(self): + self.oauth2_settings.ERROR_RESPONSE_WITH_SCOPES = True self.access_token.scope = "scope2" self.access_token.save() diff --git a/tests/test_scopes.py b/tests/test_scopes.py index d2efa5856..a310e223a 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -1,13 +1,13 @@ import json from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.urls import reverse from oauth2_provider.models import get_access_token_model, get_application_model, get_grant_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ReadWriteScopedResourceView, ScopedProtectedResourceView from .utils import get_basic_auth_header @@ -42,6 +42,19 @@ def post(self, request, *args, **kwargs): return "This is a write protected resource" +SCOPE_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Custom scope 1", + "scope2": "Custom scope 2", + "scope3": "Custom scope 3", + }, +} + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(SCOPE_SETTINGS) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -56,12 +69,7 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "scope3"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] self.application.delete() self.test_user.delete() self.dev_user.delete() @@ -325,27 +333,27 @@ def get_access_token(self, scopes): return content["access_token"] def test_improperly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) def test_properly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) diff --git a/tests/test_token_revocation.py b/tests/test_token_revocation.py index 5274ee13e..1ed1c9119 100644 --- a/tests/test_token_revocation.py +++ b/tests/test_token_revocation.py @@ -6,7 +6,6 @@ from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -29,8 +28,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() diff --git a/tests/test_validators.py b/tests/test_validators.py index 82930a9d7..0760e0290 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,10 +1,11 @@ +import pytest from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.settings import oauth2_settings from oauth2_provider.validators import RedirectURIValidator +@pytest.mark.usefixtures("oauth2_settings") class TestValidators(TestCase): def test_validate_good_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) @@ -37,7 +38,7 @@ def test_validate_custom_uri_scheme(self): def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] bad_uris = [ "http:/example.com", "HTTP://localhost", diff --git a/tox.ini b/tox.ini index 0541721cb..3d41ef9d6 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,8 @@ addopts = --cov-report= --cov-append -s +markers = + oauth2_settings: Custom OAuth2 settings to use - use with oauth2_settings fixture [testenv] commands = @@ -89,7 +91,9 @@ commands = source = oauth2_provider omit = */migrations/* - oauth2_provider/settings.py + +[coverage:report] +show_missing = True [flake8] max-line-length = 110