diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index b5e7c273bb2b9..55ad522fe3de3 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -383,7 +383,6 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: { "method": "GET", "details": DagDetails(id=id), - "user": g.user, } for id in dag_ids ] diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index ba20da16ccdfa..b9dd4596579e8 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -254,6 +254,24 @@ def is_authorized_custom_view( """ raise AirflowException(f"The resource `{fab_resource_name}` does not exist in the environment.") + def batch_is_authorized_connection( + self, + requests: Sequence[IsAuthorizedConnectionRequest], + ) -> bool: + """ + Batch version of ``is_authorized_connection``. + + By default, calls individually the ``is_authorized_connection`` API on each item in the list of + requests, which can lead to some poor performance. It is recommended to override this method in the auth + manager implementation to provide a more efficient implementation. + + :param requests: a list of requests containing the parameters for ``is_authorized_connection`` + """ + return all( + self.is_authorized_connection(method=request["method"], details=request.get("details")) + for request in requests + ) + def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], @@ -272,27 +290,6 @@ def batch_is_authorized_dag( method=request["method"], access_entity=request.get("access_entity"), details=request.get("details"), - user=request.get("user"), - ) - for request in requests - ) - - def batch_is_authorized_connection( - self, - requests: Sequence[IsAuthorizedConnectionRequest], - ) -> bool: - """ - Batch version of ``is_authorized_connection``. - - By default, calls individually the ``is_authorized_connection`` API on each item in the list of - requests. Can lead to some poor performance. It is recommended to override this method in the auth - manager implementation to provide a more efficient implementation. - - :param requests: a list of requests containing the parameters for ``is_authorized_connection`` - """ - return all( - self.is_authorized_connection( - method=request["method"], details=request.get("details"), user=request.get("user") ) for request in requests ) @@ -311,9 +308,7 @@ def batch_is_authorized_pool( :param requests: a list of requests containing the parameters for ``is_authorized_pool`` """ return all( - self.is_authorized_pool( - method=request["method"], details=request.get("details"), user=request.get("user") - ) + self.is_authorized_pool(method=request["method"], details=request.get("details")) for request in requests ) @@ -331,9 +326,7 @@ def batch_is_authorized_variable( :param requests: a list of requests containing the parameters for ``is_authorized_variable`` """ return all( - self.is_authorized_variable( - method=request["method"], details=request.get("details"), user=request.get("user") - ) + self.is_authorized_variable(method=request["method"], details=request.get("details")) for request in requests ) diff --git a/airflow/auth/managers/models/batch_apis.py b/airflow/auth/managers/models/batch_apis.py index 7cb16339a786e..ac37f68c7239b 100644 --- a/airflow/auth/managers/models/batch_apis.py +++ b/airflow/auth/managers/models/batch_apis.py @@ -21,7 +21,6 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( ConnectionDetails, DagAccessEntity, @@ -36,7 +35,6 @@ class IsAuthorizedConnectionRequest(TypedDict, total=False): method: ResourceMethod details: ConnectionDetails | None - user: BaseUser | None class IsAuthorizedDagRequest(TypedDict, total=False): @@ -45,7 +43,6 @@ class IsAuthorizedDagRequest(TypedDict, total=False): method: ResourceMethod access_entity: DagAccessEntity | None details: DagDetails | None - user: BaseUser | None class IsAuthorizedPoolRequest(TypedDict, total=False): @@ -53,7 +50,6 @@ class IsAuthorizedPoolRequest(TypedDict, total=False): method: ResourceMethod details: PoolDetails | None - user: BaseUser | None class IsAuthorizedVariableRequest(TypedDict, total=False): @@ -61,4 +57,3 @@ class IsAuthorizedVariableRequest(TypedDict, total=False): method: ResourceMethod details: VariableDetails | None - user: BaseUser | None diff --git a/airflow/providers/amazon/aws/auth_manager/avp/facade.py b/airflow/providers/amazon/aws/auth_manager/avp/facade.py index 645d57871b367..c13233b54edc2 100644 --- a/airflow/providers/amazon/aws/auth_manager/avp/facade.py +++ b/airflow/providers/amazon/aws/auth_manager/avp/facade.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence, TypedDict from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -37,6 +37,15 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser +class IsAuthorizedRequest(TypedDict, total=False): + """Represent the parameters of ``is_authorized`` method in AVP facade.""" + + method: ResourceMethod + entity_type: AvpEntities + entity_id: str | None + context: dict | None + + class AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin): """ Facade for Amazon Verified Permissions. @@ -116,6 +125,63 @@ def is_authorized( return resp["decision"] == "ALLOW" + def batch_is_authorized( + self, + *, + requests: Sequence[IsAuthorizedRequest], + user: AwsAuthManagerUser | None, + ) -> bool: + """ + Make a batch authorization decision against Amazon Verified Permissions. + + Check whether the user has permissions to access given resources. + + :param requests: the list of requests containing the method, the entity_type and the entity ID + :param user: the user + """ + if user is None: + return False + + entity_list = self._get_user_role_entities(user) + + self.log.debug("Making batch authorization request for user=%s, requests=%s", user.get_id(), requests) + + avp_requests = [ + prune_dict( + { + "principal": {"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()}, + "action": { + "actionType": get_entity_type(AvpEntities.ACTION), + "actionId": get_action_id(request["entity_type"], request["method"]), + }, + "resource": { + "entityType": get_entity_type(request["entity_type"]), + "entityId": request.get("entity_id", "*"), + }, + "context": self._build_context(request.get("context")), + } + ) + for request in requests + ] + + resp = self.avp_client.batch_is_authorized( + policyStoreId=self.avp_policy_store_id, + requests=avp_requests, + entities={"entityList": entity_list}, + ) + + self.log.debug("Authorization response: %s", resp) + + has_errors = any(len(result.get("errors", [])) > 0 for result in resp["results"]) + + if has_errors: + self.log.error( + "Error occurred while making a batch authorization decision. Result: %s", resp["results"] + ) + raise AirflowException("Error occurred while making a batch authorization decision.") + + return all(result["decision"] == "ALLOW" for result in resp["results"]) + @staticmethod def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]: user_entity = { diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 905a7e76be397..5ed09d3810779 100644 --- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -18,7 +18,7 @@ import argparse from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence, cast from flask import session, url_for @@ -26,7 +26,10 @@ from airflow.configuration import conf from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities -from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade +from airflow.providers.amazon.aws.auth_manager.avp.facade import ( + AwsAuthManagerAmazonVerifiedPermissionsFacade, + IsAuthorizedRequest, +) from airflow.providers.amazon.aws.auth_manager.cli.definition import ( AWS_AUTH_MANAGER_COMMANDS, ) @@ -40,6 +43,13 @@ try: from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod + from airflow.auth.managers.models.resource_details import ( + ConnectionDetails, + DagAccessEntity, + DagDetails, + PoolDetails, + VariableDetails, + ) except ImportError: raise AirflowOptionalProviderFeatureException( "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0" @@ -47,15 +57,16 @@ if TYPE_CHECKING: from airflow.auth.managers.models.base_user import BaseUser + from airflow.auth.managers.models.batch_apis import ( + IsAuthorizedConnectionRequest, + IsAuthorizedDagRequest, + IsAuthorizedPoolRequest, + IsAuthorizedVariableRequest, + ) from airflow.auth.managers.models.resource_details import ( AccessView, ConfigurationDetails, - ConnectionDetails, - DagAccessEntity, - DagDetails, DatasetDetails, - PoolDetails, - VariableDetails, ) from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from airflow.www.extensions.init_appbuilder import AirflowAppBuilder @@ -191,6 +202,93 @@ def is_authorized_view( entity_id=access_view.value, ) + def batch_is_authorized_connection( + self, + requests: Sequence[IsAuthorizedConnectionRequest], + ) -> bool: + """ + Batch version of ``is_authorized_connection``. + + :param requests: a list of requests containing the parameters for ``is_authorized_connection`` + """ + facade_requests: Sequence[IsAuthorizedRequest] = [ + { + "method": request["method"], + "entity_type": AvpEntities.CONNECTION, + "entity_id": cast(ConnectionDetails, request["details"]).conn_id + if request.get("details") + else None, + } + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + + def batch_is_authorized_dag( + self, + requests: Sequence[IsAuthorizedDagRequest], + ) -> bool: + """ + Batch version of ``is_authorized_dag``. + + :param requests: a list of requests containing the parameters for ``is_authorized_dag`` + """ + facade_requests: Sequence[IsAuthorizedRequest] = [ + { + "method": request["method"], + "entity_type": AvpEntities.DAG, + "entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None, + "context": { + "dag_entity": { + "string": cast(DagAccessEntity, request["access_entity"]).value, + }, + } + if request.get("access_entity") + else None, + } + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + + def batch_is_authorized_pool( + self, + requests: Sequence[IsAuthorizedPoolRequest], + ) -> bool: + """ + Batch version of ``is_authorized_pool``. + + :param requests: a list of requests containing the parameters for ``is_authorized_pool`` + """ + facade_requests: Sequence[IsAuthorizedRequest] = [ + { + "method": request["method"], + "entity_type": AvpEntities.POOL, + "entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None, + } + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + + def batch_is_authorized_variable( + self, + requests: Sequence[IsAuthorizedVariableRequest], + ) -> bool: + """ + Batch version of ``is_authorized_variable``. + + :param requests: a list of requests containing the parameters for ``is_authorized_variable`` + """ + facade_requests: Sequence[IsAuthorizedRequest] = [ + { + "method": request["method"], + "entity_type": AvpEntities.VARIABLE, + "entity_id": cast(VariableDetails, request["details"]).key + if request.get("details") + else None, + } + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + def get_url_login(self, **kwargs) -> str: return url_for("AwsAuthManagerAuthenticationViews.login") diff --git a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py index 80088e942cd03..fca4961dfeb21 100644 --- a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py +++ b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py @@ -220,3 +220,63 @@ def test_is_authorized_unsuccessful(self, facade): AirflowException, match="Error occurred while making an authorization decision." ): facade.is_authorized(method="GET", entity_type=AvpEntities.VARIABLE, user=test_user) + + @pytest.mark.parametrize( + "user, avp_response, expected", + [ + ( + test_user, + {"results": [{"decision": "ALLOW"}, {"decision": "DENY"}]}, + False, + ), + ( + test_user, + {"results": [{"decision": "ALLOW"}, {"decision": "ALLOW"}]}, + True, + ), + ( + None, + {"results": [{"decision": "ALLOW"}, {"decision": "ALLOW"}]}, + False, + ), + ], + ) + def test_batch_is_authorized_successful(self, facade, user, avp_response, expected): + mock_batch_is_authorized = Mock(return_value=avp_response) + facade.avp_client.batch_is_authorized = mock_batch_is_authorized + + with conf_vars( + { + ("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID, + } + ): + result = facade.batch_is_authorized( + requests=[ + {"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"}, + {"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"}, + ], + user=user, + ) + + assert result == expected + + def test_batch_is_authorized_unsuccessful(self, facade): + avp_response = {"results": [{}, {"errors": []}, {"errors": [{"errorDescription": "Error"}]}]} + mock_batch_is_authorized = Mock(return_value=avp_response) + facade.avp_client.batch_is_authorized = mock_batch_is_authorized + + with conf_vars( + { + ("aws_auth_manager", "avp_policy_store_id"): AVP_POLICY_STORE_ID, + } + ): + with pytest.raises( + AirflowException, match="Error occurred while making a batch authorization decision." + ): + facade.batch_is_authorized( + requests=[ + {"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"}, + {"method": "GET", "entity_type": AvpEntities.VARIABLE, "entity_id": "var1"}, + ], + user=test_user, + ) diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index 4f31703d1c822..df4c45255fde2 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -361,6 +361,154 @@ def test_is_authorized_view( ) assert result + @patch.object(AwsAuthManager, "avp_facade") + @patch.object(AwsAuthManager, "get_user") + def test_batch_is_authorized_connection( + self, + mock_get_user, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_connection( + requests=[{"method": "GET"}, {"method": "GET", "details": ConnectionDetails(conn_id="conn_id")}] + ) + + mock_get_user.assert_called_once() + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.CONNECTION, + "entity_id": None, + }, + { + "method": "GET", + "entity_type": AvpEntities.CONNECTION, + "entity_id": "conn_id", + }, + ], + user=ANY, + ) + assert result + + @patch.object(AwsAuthManager, "avp_facade") + @patch.object(AwsAuthManager, "get_user") + def test_batch_is_authorized_dag( + self, + mock_get_user, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_dag( + requests=[ + {"method": "GET"}, + {"method": "GET", "details": DagDetails(id="dag_1")}, + {"method": "GET", "details": DagDetails(id="dag_1"), "access_entity": DagAccessEntity.CODE}, + ] + ) + + mock_get_user.assert_called_once() + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.DAG, + "entity_id": None, + "context": None, + }, + { + "method": "GET", + "entity_type": AvpEntities.DAG, + "entity_id": "dag_1", + "context": None, + }, + { + "method": "GET", + "entity_type": AvpEntities.DAG, + "entity_id": "dag_1", + "context": { + "dag_entity": { + "string": DagAccessEntity.CODE.value, + }, + }, + }, + ], + user=ANY, + ) + assert result + + @patch.object(AwsAuthManager, "avp_facade") + @patch.object(AwsAuthManager, "get_user") + def test_batch_is_authorized_pool( + self, + mock_get_user, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_pool( + requests=[{"method": "GET"}, {"method": "GET", "details": PoolDetails(name="pool1")}] + ) + + mock_get_user.assert_called_once() + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.POOL, + "entity_id": None, + }, + { + "method": "GET", + "entity_type": AvpEntities.POOL, + "entity_id": "pool1", + }, + ], + user=ANY, + ) + assert result + + @patch.object(AwsAuthManager, "avp_facade") + @patch.object(AwsAuthManager, "get_user") + def test_batch_is_authorized_variable( + self, + mock_get_user, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_variable( + requests=[{"method": "GET"}, {"method": "GET", "details": VariableDetails(key="var1")}] + ) + + mock_get_user.assert_called_once() + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.VARIABLE, + "entity_id": None, + }, + { + "method": "GET", + "entity_type": AvpEntities.VARIABLE, + "entity_id": "var1", + }, + ], + user=ANY, + ) + assert result + @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") def test_get_url_login(self, mock_url_for, auth_manager): auth_manager.get_url_login()