From 6336456d83f8f111c842b2b53d1e89627f2502c8 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Tue, 6 Feb 2024 20:51:56 +0000 Subject: [PATCH] fix: openID provider validation flow (#2186) * fix: openID provider validation flow * remove test cleanup --- flask_appbuilder/security/manager.py | 8 +++ flask_appbuilder/security/views.py | 6 ++- .../general/security/login_oid.html | 8 +-- tests/config_oid.py | 29 ++++++++++ tests/test_mvc_oauth.py | 2 +- tests/test_mvc_oid.py | 53 +++++++++++++++++++ tests/test_security_api.py | 2 + 7 files changed, 100 insertions(+), 8 deletions(-) create mode 100644 tests/config_oid.py create mode 100644 tests/test_mvc_oid.py diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 80a91b8cf4..8f86b7ee51 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -1447,6 +1447,14 @@ def _has_view_access( # If it's not a builtin role check against database store roles return self.exist_permission_on_roles(view_name, permission_name, db_role_ids) + def get_oid_identity_url(self, provider_name: str) -> Optional[str]: + """ + Returns the OIDC identity provider URL + """ + for provider in self.openid_providers: + if provider.get("name") == provider_name: + return provider.get("url") + def get_user_roles(self, user) -> List[object]: """ Get current user roles, if user is not authenticated returns the public role diff --git a/flask_appbuilder/security/views.py b/flask_appbuilder/security/views.py index fd4db67417..c38e23f395 100644 --- a/flask_appbuilder/security/views.py +++ b/flask_appbuilder/security/views.py @@ -565,8 +565,12 @@ def login_handler(self): form = LoginForm_oid() if form.validate_on_submit(): session["remember_me"] = form.remember_me.data + identity_url = self.appbuilder.sm.get_oid_identity_url(form.openid.data) + if identity_url is None: + flash(as_unicode(self.invalid_login_message), "warning") + return redirect(self.appbuilder.get_url_for_login) return self.appbuilder.sm.oid.try_login( - form.openid.data, + identity_url, ask_for=self.oid_ask_for, ask_for_optional=self.oid_ask_for_optional, ) diff --git a/flask_appbuilder/templates/appbuilder/general/security/login_oid.html b/flask_appbuilder/templates/appbuilder/general/security/login_oid.html index 9e7a972808..77d8457bb3 100644 --- a/flask_appbuilder/templates/appbuilder/general/security/login_oid.html +++ b/flask_appbuilder/templates/appbuilder/general/security/login_oid.html @@ -36,13 +36,9 @@ {{ form.username(size = 80, class = "hidden form-control", autofocus = true) }} - - -
-
+ {{ form.remember_me }} Remember Me
"}, + {"name": "Flickr", "url": "http://www.flickr.com/"}, + {"name": "OpenStack", "url": "https://openstackid.org/"}, +] + +WTF_CSRF_ENABLED = False + +# Will allow user self registration +AUTH_USER_REGISTRATION = True + +# The default user self registration role for all users +AUTH_USER_REGISTRATION_ROLE = "Admin" diff --git a/tests/test_mvc_oauth.py b/tests/test_mvc_oauth.py index b51d64717b..6e9a5227b6 100644 --- a/tests/test_mvc_oauth.py +++ b/tests/test_mvc_oauth.py @@ -26,7 +26,7 @@ def get(self, item): return UserInfoReponseMock() -class APICSRFTestCase(FABTestCase): +class MVCOAuthTestCase(FABTestCase): def setUp(self): from flask import Flask from flask_wtf import CSRFProtect diff --git a/tests/test_mvc_oid.py b/tests/test_mvc_oid.py new file mode 100644 index 0000000000..2965f4369e --- /dev/null +++ b/tests/test_mvc_oid.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock + +from flask_appbuilder import SQLA +from tests.base import FABTestCase + + +class MVCOIDTestCase(FABTestCase): + def setUp(self): + from flask import Flask + from flask_appbuilder import AppBuilder + + self.app = Flask(__name__) + self.app.config.from_object("tests.config_oid") + self.db = SQLA(self.app) + self.appbuilder = AppBuilder(self.app, self.db.session) + + def test_oid_login_get(self): + """ + OID: Test login get + """ + self.appbuilder.sm.oid.try_login = MagicMock(return_value="Login ok") + + with self.app.test_client() as client: + response = client.get("/login/") + self.assertEqual(response.status_code, 200) + for provider in self.app.config["OPENID_PROVIDERS"]: + self.assertIn(provider["name"], response.data.decode("utf-8")) + + def test_oid_login_post(self): + """ + OID: Test login post with a valid provider + """ + self.appbuilder.sm.oid.try_login = MagicMock(return_value="Login ok") + + with self.app.test_client() as client: + response = client.post("/login/", data=dict(openid="OpenStack")) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, b"Login ok") + self.appbuilder.sm.oid.try_login.assert_called_with( + "https://openstackid.org/", ask_for=["email"], ask_for_optional=[] + ) + + def test_oid_login_post_invalid_provider(self): + """ + OID: Test login post with an invalid provider + """ + self.appbuilder.sm.oid.try_login = MagicMock(return_value="Not Ok") + + with self.app.test_client() as client: + response = client.post("/login/", data=dict(openid="DoesNotExist")) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.location, "/login/") + self.appbuilder.sm.oid.try_login.assert_not_called() diff --git a/tests/test_security_api.py b/tests/test_security_api.py index b0e4b419e6..18606e9f60 100644 --- a/tests/test_security_api.py +++ b/tests/test_security_api.py @@ -444,6 +444,8 @@ def setUp(self): if hasattr(b, "datamodel") and b.datamodel.session is not None: b.datamodel.session = self.db.session + self.create_default_users(self.appbuilder) + def tearDown(self): self.appbuilder.session.close() engine = self.appbuilder.session.get_bind(mapper=None, clause=None)