Skip to content

Commit

Permalink
Check user is active (#26635)
Browse files Browse the repository at this point in the history
(cherry picked from commit 59707cd)
  • Loading branch information
jedcunningham committed Sep 27, 2022
1 parent d81b297 commit 12bfb57
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 3 deletions.
7 changes: 6 additions & 1 deletion airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from airflow.www.extensions.init_jinja_globals import init_jinja_globals
from airflow.www.extensions.init_manifest_files import configure_manifest_files
from airflow.www.extensions.init_robots import init_robots
from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection
from airflow.www.extensions.init_security import (
init_api_experimental_auth,
init_check_user_active,
init_xframe_protection,
)
from airflow.www.extensions.init_session import init_airflow_session_interface
from airflow.www.extensions.init_views import (
init_api_connexion,
Expand Down Expand Up @@ -152,6 +156,7 @@ def create_app(config=None, testing=False):
init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
init_airflow_session_interface(flask_app)
init_check_user_active(flask_app)
return flask_app


Expand Down
11 changes: 11 additions & 0 deletions airflow/www/extensions/init_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import logging
from importlib import import_module

from flask import g, redirect, url_for
from flask_login import logout_user

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException

Expand Down Expand Up @@ -60,3 +63,11 @@ def init_api_experimental_auth(app):
except ImportError as err:
log.critical("Cannot import %s for API authentication due to: %s", backend, err)
raise AirflowException(err)


def init_check_user_active(app):
@app.before_request
def check_user_active():
if g.user is not None and not g.user.is_anonymous and not g.user.is_active:
logout_user()
return redirect(url_for(app.appbuilder.sm.auth_view.endpoint + ".login"))
1 change: 1 addition & 0 deletions tests/test_utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def no_op(*args, **kwargs):
"init_xframe_protection",
"init_airflow_session_interface",
"init_appbuilder",
"init_check_user_active",
]

@functools.wraps(f)
Expand Down
1 change: 1 addition & 0 deletions tests/www/views/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def app(examples_dag_bag):
"init_jinja_globals",
"init_plugins",
"init_airflow_session_interface",
"init_check_user_active",
]
)
def factory():
Expand Down
14 changes: 14 additions & 0 deletions tests/www/views/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,17 @@ def test_session_id_rotates(app, user_client):
new_session_cookie = get_session_cookie(user_client)
assert new_session_cookie is not None
assert old_session_cookie.value != new_session_cookie.value


def test_check_active_user(app, user_client):
user = app.appbuilder.sm.find_user(username="test_user")
user.active = False
resp = user_client.get("/home")
assert resp.status_code == 302
assert "/login" in resp.headers.get("Location")

# And they were logged out
user.active = True
resp = user_client.get("/home")
assert resp.status_code == 302
assert "/login" in resp.headers.get("Location")
13 changes: 11 additions & 2 deletions tests/www/views/test_views_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,18 @@
from tests.test_utils.www import check_content_in_response, check_content_not_in_response


def test_index(admin_client):
def test_index_redirect(admin_client):
resp = admin_client.get('/')
assert resp.status_code == 302
assert '/home' in resp.headers.get("Location")

resp = admin_client.get('/', follow_redirects=True)
check_content_in_response('DAGs', resp)


def test_homepage_query_count(admin_client):
with assert_queries_count(16):
resp = admin_client.get('/', follow_redirects=True)
resp = admin_client.get('/home')
check_content_in_response('DAGs', resp)


Expand Down

0 comments on commit 12bfb57

Please sign in to comment.