diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index f5554005264c9..59f63c8ffb339 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -16,26 +16,29 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence from sqlalchemy import func, select from airflow.api_connexion import security -from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.exceptions import NotFound, PermissionDenied from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.error_schema import ( ImportErrorCollection, import_error_collection_schema, import_error_schema, ) -from airflow.auth.managers.models.resource_details import AccessView +from airflow.auth.managers.models.resource_details import AccessView, DagDetails +from airflow.models.dag import DagModel from airflow.models.errors import ImportError as ImportErrorModel from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session from airflow.api_connexion.types import APIResponse + from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest @security.requires_access_view(AccessView.IMPORT_ERRORS) @@ -43,12 +46,29 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get an import error.""" error = session.get(ImportErrorModel, import_error_id) - if error is None: raise NotFound( "Import error not found", detail=f"The ImportError with import_error_id: `{import_error_id}` was not found", ) + session.expunge(error) + + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + if not can_read_all_dags: + readable_dag_ids = security.get_readable_dags() + file_dag_ids = { + dag_id[0] + for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all() + } + + # Can the user read any DAGs in the file? + if not readable_dag_ids.intersection(file_dag_ids): + raise PermissionDenied(detail="You do not have read permission on any of the DAGs in the file") + + # Check if user has read access to all the DAGs defined in the file + if not file_dag_ids.issubset(readable_dag_ids): + error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file" + return import_error_schema.dump(error) @@ -65,10 +85,41 @@ def get_import_errors( """Get all import errors.""" to_replace = {"import_error_id": "id"} allowed_filter_attrs = ["import_error_id", "timestamp", "filename"] - total_entries = session.scalars(func.count(ImportErrorModel.id)).one() + count_query = select(func.count(ImportErrorModel.id)) query = select(ImportErrorModel) query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs) + + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + + if not can_read_all_dags: + # if the user doesn't have access to all DAGs, only display errors from visible DAGs + readable_dag_ids = security.get_readable_dags() + dagfiles_subq = ( + select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids)).subquery() + ) + query = query.where(ImportErrorModel.filename.in_(dagfiles_subq)) + count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_subq)) + + total_entries = session.scalars(count_query).one() import_errors = session.scalars(query.offset(offset).limit(limit)).all() + + if not can_read_all_dags: + for import_error in import_errors: + # Check if user has read access to all the DAGs defined in the file + file_dag_ids = ( + session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all() + ) + requests: Sequence[IsAuthorizedDagRequest] = [ + { + "method": "GET", + "details": DagDetails(id=dag_id[0]), + } + for dag_id in file_dag_ids + ] + if not get_auth_manager().batch_is_authorized_dag(requests): + session.expunge(import_error) + import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file" + return import_error_collection_schema.dump( ImportErrorCollection(import_errors=import_errors, total_entries=total_entries) ) diff --git a/airflow/www/views.py b/airflow/www/views.py index 1625307890559..b99062c551b60 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -147,6 +147,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest from airflow.models.dag import DAG from airflow.models.operator import Operator @@ -935,20 +936,44 @@ def index(self): owner_links_dict = DagOwnerAttributes.get_all(session) - import_errors = select(errors.ImportError).order_by(errors.ImportError.id) - - if not get_auth_manager().is_authorized_dag(method="GET"): - # if the user doesn't have access to all DAGs, only display errors from visible DAGs - import_errors = import_errors.join( - DagModel, DagModel.fileloc == errors.ImportError.filename - ).where(DagModel.dag_id.in_(filter_dag_ids)) + if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS): + import_errors = select(errors.ImportError).order_by(errors.ImportError.id) + + can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET") + if not can_read_all_dags: + # if the user doesn't have access to all DAGs, only display errors from visible DAGs + import_errors = import_errors.where( + errors.ImportError.filename.in_( + select(DagModel.fileloc) + .distinct() + .where(DagModel.dag_id.in_(filter_dag_ids)) + .subquery() + ) + ) - import_errors = session.scalars(import_errors) - for import_error in import_errors: - flash( - f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}", - "dag_import_error", - ) + import_errors = session.scalars(import_errors) + for import_error in import_errors: + stacktrace = import_error.stacktrace + if not can_read_all_dags: + # Check if user has read access to all the DAGs defined in the file + file_dag_ids = ( + session.query(DagModel.dag_id) + .filter(DagModel.fileloc == import_error.filename) + .all() + ) + requests: Sequence[IsAuthorizedDagRequest] = [ + { + "method": "GET", + "details": DagDetails(id=dag_id[0]), + } + for dag_id in file_dag_ids + ] + if not get_auth_manager().batch_is_authorized_dag(requests): + stacktrace = "REDACTED - you do not have read permission on all DAGs in the file" + flash( + f"Broken DAG: [{import_error.filename}]\r{stacktrace}", + "dag_import_error", + ) from airflow.plugins_manager import import_errors as plugin_import_errors diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 33550862ab95c..fae1312a32058 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -21,16 +21,19 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models.dag import DagModel from airflow.models.errors import ImportError from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_import_errors +from tests.test_utils.db import clear_db_dags, clear_db_import_errors pytestmark = pytest.mark.db_test +TEST_DAG_IDS = ["test_dag", "test_dag2"] + @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): @@ -39,14 +42,34 @@ def configured_app(minimal_app_for_api): app, # type:ignore username="test", role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), + ], # type: ignore ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user( + app, # type:ignore + username="test_single_dag", + role_name="TestSingleDAG", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore + ) + # For some reason, DAG level permissions are not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestSingleDAG", + "perms": [(permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_DAG_IDS[0]))], + } + ] + ) - yield minimal_app_for_api + yield app delete_user(app, username="test") # type: ignore delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test_single_dag") # type: ignore class TestBaseImportError: @@ -58,9 +81,11 @@ def setup_attrs(self, configured_app) -> None: self.client = self.app.test_client() # type:ignore clear_db_import_errors() + clear_db_dags() def teardown_method(self) -> None: clear_db_import_errors() + clear_db_dags() @staticmethod def _normalize_import_errors(import_errors): @@ -121,6 +146,72 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + def test_should_raise_403_forbidden_without_dag_read(self, session): + import_error = ImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 403 + + def test_should_return_200_with_single_dag_read(self, session): + dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + class TestGetImportErrorsEndpoint(TestBaseImportError): def test_get_import_errors(self, session): @@ -231,6 +322,71 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) + def test_get_import_errors_single_dag(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = f"/tmp/{dag_id}.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + importerror = ImportError( + filename=fake_filename, + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/test_dag.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data + + def test_get_import_errors_single_dag_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = "/tmp/all_in_one.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + + importerror = ImportError( + filename="/tmp/all_in_one.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/all_in_one.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data + class TestGetImportErrorsEndpointPagination(TestBaseImportError): @pytest.mark.parametrize( diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index c89ab89e9b79e..03d0f7a58f834 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -111,6 +111,30 @@ def test_home_status_filter_cookie(admin_client): assert "all" == flask.session[FILTER_STATUS_COOKIE] +@pytest.fixture(scope="module") +def user_no_importerror(app): + """Create User that cannot access Import Errors""" + return create_user( + app, + username="user_no_importerrors", + role_name="role_no_importerrors", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + ], + ) + + +@pytest.fixture() +def client_no_importerror(app, user_no_importerror): + """Client for User that cannot access Import Errors""" + return client_with_login( + app, + username="user_no_importerrors", + password="user_no_importerrors", + ) + + @pytest.fixture(scope="module") def user_single_dag(app): """Create User that can only access the first DAG from TEST_FILTER_DAG_IDS""" @@ -120,6 +144,7 @@ def user_single_dag(app): role_name="role_single_dag", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), (permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_FILTER_DAG_IDS[0])), ], ) @@ -232,6 +257,24 @@ def broken_dags_with_read_perm(tmp_path, working_dags_with_read_perm): _process_file(path, session) +@pytest.fixture() +def broken_dags_after_working(tmp_path): + # First create and process a DAG file that works + path = tmp_path / "all_in_one.py" + with create_session() as session: + contents = "from airflow import DAG\n" + for i, dag_id in enumerate(TEST_FILTER_DAG_IDS): + contents += f"dag{i} = DAG('{dag_id}')\n" + path.write_text(contents) + _process_file(path, session) + + # Then break it! + with create_session() as session: + contents += "foobar()" + path.write_text(contents) + _process_file(path, session) + + def test_home_filter_tags(working_dags, admin_client): with admin_client: admin_client.get("home?tags=example&tags=data", follow_redirects=True) @@ -249,6 +292,12 @@ def test_home_importerrors(broken_dags, user_client): check_content_in_response(f"/{dag_id}.py", resp) +def test_home_no_importerrors_perm(broken_dags, client_no_importerror): + # Users without "can read on import errors" don't see any import errors + resp = client_no_importerror.get("home", follow_redirects=True) + check_content_not_in_response("Import Errors", resp) + + @pytest.mark.parametrize( "page", [ @@ -266,11 +315,23 @@ def test_home_importerrors_filtered_singledag_user(broken_dags_with_read_perm, c check_content_in_response("Import Errors", resp) # They can see the first DAGs import error check_content_in_response(f"/{TEST_FILTER_DAG_IDS[0]}.py", resp) + check_content_in_response("Traceback", resp) # But not the rest for dag_id in TEST_FILTER_DAG_IDS[1:]: check_content_not_in_response(f"/{dag_id}.py", resp) +def test_home_importerrors_missing_read_on_all_dags_in_file(broken_dags_after_working, client_single_dag): + # If a user doesn't have READ on all DAGs in a file, that files traceback is redacted + resp = client_single_dag.get("home", follow_redirects=True) + check_content_in_response("Import Errors", resp) + # They can see the DAG file has an import error + check_content_in_response("all_in_one.py", resp) + # And the traceback is redacted + check_content_not_in_response("Traceback", resp) + check_content_in_response("REDACTED", resp) + + def test_home_dag_list(working_dags, user_client): # Users with "can read on DAGs" gets all DAGs resp = user_client.get("home", follow_redirects=True)